def run_training():
  """Train MNIST for a number of steps and export the trained graph.

  Reads the MNIST data sets into a fresh temporary directory (or uses fake
  data when FLAGS.fake_data is set), builds the inference/loss/training graph,
  runs FLAGS.max_steps training steps and, periodically, writes summaries to
  FLAGS.train_dir, saves checkpoints there and evaluates on the training,
  validation and test sets. Finally the model is saved to
  FLAGS.model_dir/export.

  The JSON-encoded tensor names of the model inputs ('key', 'image') and
  outputs ('key', 'prediction', 'scores') are stored in the graph collections
  'inputs' and 'outputs' so that a consumer of the saved graph can look the
  tensors up by name.
  """
  # Get the sets of images and labels for training, validation, and
  # test on MNIST.
  data_sets = input_data.read_data_sets(tempfile.mkdtemp(), FLAGS.fake_data)

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Generate placeholders for the images and labels and mark as input.
    placeholders = placeholder_inputs()
    keys_placeholder, images_placeholder, labels_placeholder = placeholders
    inputs = {'key': keys_placeholder.name, 'image': images_placeholder.name}
    tf.add_to_collection('inputs', json.dumps(inputs))

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images_placeholder, FLAGS.hidden1, FLAGS.hidden2)

    # Add to the Graph the Ops for loss calculation.
    loss = mnist.loss(logits, labels_placeholder)

    # To be able to extract the id, we need to add the identity function.
    keys = tf.identity(keys_placeholder)

    # The prediction will be the index in logits with the highest score.
    # We also use a softmax operation to produce a probability distribution
    # over all possible digits.
    prediction = tf.argmax(logits, 1)
    scores = tf.nn.softmax(logits)

    # Mark the outputs.
    outputs = {'key': keys.name,
               'prediction': prediction.name,
               'scores': scores.name}
    tf.add_to_collection('outputs', json.dumps(outputs))

    # Add to the Graph the Ops that calculate and apply gradients.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # Add the Op to compare the logits to the labels during evaluation.
    eval_correct = mnist.evaluation(logits, labels_placeholder)

    # Build the summary operation based on the TF collection of Summaries.
    # Detect the API instead of comparing tf.__version__ as strings: a
    # lexicographic compare misorders versions (e.g. '0.9.0' < '0.12' is
    # False). Only the attribute lookup is inside the try so an
    # AttributeError raised while building the op is not mistaken for a
    # missing API.
    # TODO(b/33420312): drop the fallback once 0.12 is fully rolled out.
    try:
      merge_all_summaries = tf.contrib.deprecated.merge_all_summaries
    except AttributeError:
      merge_all_summaries = tf.merge_all_summaries
    summary_op = merge_all_summaries()

    # Add the variable initializer Op.
    init = tf.initialize_all_variables()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    # tf.summary.FileWriter replaced tf.train.SummaryWriter in TF 0.12; fall
    # back to the old name so this works on the same versions as summary_op.
    try:
      summary_writer_cls = tf.summary.FileWriter
    except AttributeError:
      summary_writer_cls = tf.train.SummaryWriter

    # Create a session for running Ops on the Graph. The with-statement closes
    # it once training and export are done (the original leaked it).
    with tf.Session() as sess:
      # Instantiate a SummaryWriter to output summaries and the Graph.
      summary_writer = summary_writer_cls(FLAGS.train_dir, sess.graph)
      try:
        # And then after everything is built:
        # Run the Op to initialize the variables.
        sess.run(init)

        # Start the training loop.
        for step in xrange(FLAGS.max_steps):
          start_time = time.time()

          # Fill a feed dictionary with the actual set of images and labels
          # for this particular training step.
          feed_dict = fill_feed_dict(data_sets.train, images_placeholder,
                                     labels_placeholder)

          # Run one step of the model. The return values are the activations
          # from the `train_op` (which is discarded) and the `loss` Op.
          _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

          duration = time.time() - start_time

          # Write the summaries and print an overview fairly often.
          if step % 100 == 0:
            # Print status to stdout.
            print('Step %d: loss = %.2f (%.3f sec)' %
                  (step, loss_value, duration))
            # Update the events file.
            summary_str = sess.run(summary_op, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, step)
            summary_writer.flush()

          # Save a checkpoint and evaluate the model periodically.
          if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
            checkpoint_file = os.path.join(FLAGS.train_dir, 'checkpoint')
            saver.save(sess, checkpoint_file, global_step=step)
            # Evaluate against the training, validation and test sets.
            for label, data_set in (('Training', data_sets.train),
                                    ('Validation', data_sets.validation),
                                    ('Test', data_sets.test)):
              print('%s Data Eval:' % label)
              do_eval(sess, eval_correct, images_placeholder,
                      labels_placeholder, data_set)
      finally:
        # Flush pending events and release the events file even on failure.
        summary_writer.close()

      # Export the model so that it can be loaded and used later for
      # predictions.
      file_io.create_dir(FLAGS.model_dir)
      saver.save(sess, os.path.join(FLAGS.model_dir, 'export'))
def run_training():
  """Train MNIST for a number of steps and export a SavedModel.

  If FLAGS.input_path is set, the files listed in INPUT_FILES are first
  copied from that location with `gsutil` into a fresh temporary directory,
  which is then passed to input_data.read_data_sets (FLAGS.fake_data is
  forwarded too). The graph is trained for FLAGS.max_steps steps. Along the
  way it periodically writes summaries to FLAGS.train_dir, saves checkpoints
  there and evaluates on the training, validation and test sets. Finally the
  model is exported with saved_model_util.simple_save to
  FLAGS.model_dir/saved_model, with inputs 'key'/'image' and outputs
  'key'/'prediction'/'scores'.

  Raises:
    subprocess.CalledProcessError: if the gsutil copy exits non-zero.
  """
  # Get the sets of images and labels for training, validation, and
  # test on MNIST. If input_path is specified, download the data from GCS to
  # the folder expected by read_data_sets.
  data_dir = tempfile.mkdtemp()
  if FLAGS.input_path:
    files = [os.path.join(FLAGS.input_path, file_name)
             for file_name in INPUT_FILES]
    subprocess.check_call(['gsutil', '-m', '-q', 'cp', '-r'] + files +
                          [data_dir])
  data_sets = input_data.read_data_sets(data_dir, FLAGS.fake_data)

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Generate placeholders for the images and labels and mark as input.
    placeholders = placeholder_inputs()
    keys_placeholder, images_placeholder, labels_placeholder = placeholders

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images_placeholder, FLAGS.hidden1, FLAGS.hidden2)

    # Add to the Graph the Ops for loss calculation.
    loss = mnist.loss(logits, labels_placeholder)

    # To be able to extract the id, we need to add the identity function.
    keys = tf.identity(keys_placeholder)

    # The prediction will be the index in logits with the highest score.
    # We also use a softmax operation to produce a probability distribution
    # over all possible digits.
    prediction = tf.argmax(logits, 1)
    scores = tf.nn.softmax(logits)

    # Add to the Graph the Ops that calculate and apply gradients.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # Add the Op to compare the logits to the labels during evaluation.
    eval_correct = mnist.evaluation(logits, labels_placeholder)

    # Build the summary operation based on the TF collection of Summaries.
    # Prefer the TF >= 0.12 location and fall back to the older one. Only the
    # attribute lookup is inside the try, so an AttributeError raised while
    # building the op is not mistaken for a missing API.
    # Remove this fallback once TensorFlow 0.12 is standard.
    try:
      merge_all_summaries = tf.contrib.deprecated.merge_all_summaries
    except AttributeError:
      merge_all_summaries = tf.merge_all_summaries
    summary_op = merge_all_summaries()

    # Add the variable initializer Op.
    init = tf.initialize_all_variables()

    # Create a saver for writing legacy training checkpoints.
    saver = tf.train.Saver()

    # tf.summary.FileWriter replaced tf.train.SummaryWriter in TF 0.12.
    # Remove this fallback once TensorFlow 0.12 is standard.
    try:
      summary_writer_cls = tf.summary.FileWriter
    except AttributeError:
      summary_writer_cls = tf.train.SummaryWriter

    # Create a session for running Ops on the Graph. The with-statement closes
    # it once training and export are done (the original leaked it).
    with tf.Session() as sess:
      # Instantiate a SummaryWriter to output summaries and the Graph.
      summary_writer = summary_writer_cls(FLAGS.train_dir, sess.graph)
      try:
        # And then after everything is built:
        # Run the Op to initialize the variables.
        sess.run(init)

        # Start the training loop.
        for step in xrange(FLAGS.max_steps):
          start_time = time.time()

          # Fill a feed dictionary with the actual set of images and labels
          # for this particular training step.
          feed_dict = fill_feed_dict(data_sets.train, images_placeholder,
                                     labels_placeholder)

          # Run one step of the model. The return values are the activations
          # from the `train_op` (which is discarded) and the `loss` Op.
          _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

          duration = time.time() - start_time

          # Write the summaries and print an overview fairly often.
          if step % 100 == 0:
            # Print status to stdout.
            print('Step %d: loss = %.2f (%.3f sec)' %
                  (step, loss_value, duration))
            # Update the events file.
            summary_str = sess.run(summary_op, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, step)
            summary_writer.flush()

          # Save a checkpoint and evaluate the model periodically.
          if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
            checkpoint_file = os.path.join(FLAGS.train_dir, 'checkpoint')
            saver.save(sess, checkpoint_file, global_step=step)
            # Evaluate against the training, validation and test sets.
            for label, data_set in (('Training', data_sets.train),
                                    ('Validation', data_sets.validation),
                                    ('Test', data_sets.test)):
              print('%s Data Eval:' % label)
              do_eval(sess, eval_correct, images_placeholder,
                      labels_placeholder, data_set)
      finally:
        # Flush pending events and release the events file even on failure.
        summary_writer.close()

      # Export the trained graph as a SavedModel for later predictions.
      saved_model_dir = os.path.join(FLAGS.model_dir, 'saved_model')
      file_io.create_dir(FLAGS.model_dir)
      saved_model_util.simple_save(
          sess,
          saved_model_dir,
          inputs={
              'key': keys_placeholder,
              'image': images_placeholder
          },
          outputs={
              'key': keys,
              'prediction': prediction,
              'scores': scores
          })
      logging.debug('Saved model path %s', saved_model_dir)
def run_training():
  """Train MNIST for a number of steps and export the trained graph.

  Reads the MNIST data sets into a fresh temporary directory (or uses fake
  data when FLAGS.fake_data is set), builds the inference/loss/training graph,
  runs FLAGS.max_steps training steps and, periodically, writes summaries to
  FLAGS.train_dir, saves checkpoints there and evaluates on the training,
  validation and test sets. Finally the model is saved to
  FLAGS.model_dir/export.

  The JSON-encoded tensor names of the model inputs ('key', 'image') and
  outputs ('key', 'prediction', 'scores') are stored in the graph collections
  'inputs' and 'outputs' so that a consumer of the saved graph can look the
  tensors up by name.
  """
  # Get the sets of images and labels for training, validation, and
  # test on MNIST.
  data_sets = input_data.read_data_sets(tempfile.mkdtemp(), FLAGS.fake_data)

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Generate placeholders for the images and labels and mark as input.
    placeholders = placeholder_inputs()
    keys_placeholder, images_placeholder, labels_placeholder = placeholders
    inputs = {'key': keys_placeholder.name, 'image': images_placeholder.name}
    tf.add_to_collection('inputs', json.dumps(inputs))

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images_placeholder, FLAGS.hidden1, FLAGS.hidden2)

    # Add to the Graph the Ops for loss calculation.
    loss = mnist.loss(logits, labels_placeholder)

    # To be able to extract the id, we need to add the identity function.
    keys = tf.identity(keys_placeholder)

    # The prediction will be the index in logits with the highest score.
    # We also use a softmax operation to produce a probability distribution
    # over all possible digits.
    prediction = tf.argmax(logits, 1)
    scores = tf.nn.softmax(logits)

    # Mark the outputs.
    outputs = {'key': keys.name,
               'prediction': prediction.name,
               'scores': scores.name}
    tf.add_to_collection('outputs', json.dumps(outputs))

    # Add to the Graph the Ops that calculate and apply gradients.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # Add the Op to compare the logits to the labels during evaluation.
    eval_correct = mnist.evaluation(logits, labels_placeholder)

    # Build the summary operation based on the TF collection of Summaries.
    # tf.merge_all_summaries was removed in TF 1.0; prefer the location that
    # exists from 0.12 on and fall back to the legacy name on older releases.
    # Only the attribute lookup is inside the try.
    try:
      merge_all_summaries = tf.contrib.deprecated.merge_all_summaries
    except AttributeError:
      merge_all_summaries = tf.merge_all_summaries
    summary_op = merge_all_summaries()

    # Add the variable initializer Op.
    init = tf.initialize_all_variables()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    # tf.train.SummaryWriter was renamed tf.summary.FileWriter in TF 0.12 and
    # removed in 1.0; use whichever this TensorFlow provides.
    try:
      summary_writer_cls = tf.summary.FileWriter
    except AttributeError:
      summary_writer_cls = tf.train.SummaryWriter

    # Create a session for running Ops on the Graph. The with-statement closes
    # it once training and export are done (the original leaked it).
    with tf.Session() as sess:
      # Instantiate a SummaryWriter to output summaries and the Graph.
      summary_writer = summary_writer_cls(FLAGS.train_dir, sess.graph)
      try:
        # And then after everything is built:
        # Run the Op to initialize the variables.
        sess.run(init)

        # Start the training loop.
        for step in xrange(FLAGS.max_steps):
          start_time = time.time()

          # Fill a feed dictionary with the actual set of images and labels
          # for this particular training step.
          feed_dict = fill_feed_dict(data_sets.train, images_placeholder,
                                     labels_placeholder)

          # Run one step of the model. The return values are the activations
          # from the `train_op` (which is discarded) and the `loss` Op.
          _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

          duration = time.time() - start_time

          # Write the summaries and print an overview fairly often.
          if step % 100 == 0:
            # Print status to stdout.
            print('Step %d: loss = %.2f (%.3f sec)' %
                  (step, loss_value, duration))
            # Update the events file.
            summary_str = sess.run(summary_op, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, step)
            summary_writer.flush()

          # Save a checkpoint and evaluate the model periodically.
          if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
            checkpoint_file = os.path.join(FLAGS.train_dir, 'checkpoint')
            saver.save(sess, checkpoint_file, global_step=step)
            # Evaluate against the training, validation and test sets.
            for label, data_set in (('Training', data_sets.train),
                                    ('Validation', data_sets.validation),
                                    ('Test', data_sets.test)):
              print('%s Data Eval:' % label)
              do_eval(sess, eval_correct, images_placeholder,
                      labels_placeholder, data_set)
      finally:
        # Flush pending events and release the events file even on failure.
        summary_writer.close()

      # Export the model so that it can be loaded and used later for
      # predictions.
      file_io.create_dir(FLAGS.model_dir)
      saver.save(sess, os.path.join(FLAGS.model_dir, 'export'))