def evaluate(args):
    """Run one round of evaluation, yielding rmse.

    Generator. Builds a separate evaluation graph that reads
    ``args.eval_data_paths`` once (``num_epochs=1``, ``shuffle=False``) and
    computes a streaming root-mean-squared error of the model output against
    the targets. Each time the generator is advanced it restores the latest
    checkpoint found in ``<args.output_path>/logdir``, runs ``tf_evaluation``
    on it and yields the resulting rmse. Eval summaries are written to
    ``<args.output_path>/eval``.

    The loop condition is checked before each evaluation against the global
    step read from the most recently restored checkpoint. Once a checkpoint
    with global step >= ``args.max_steps`` has been evaluated, the generator
    is exhausted and a further ``next()`` raises ``StopIteration``.

    Args:
      args: Parsed arguments; this function reads ``eval_data_paths``,
        ``metadata_path``, ``output_path`` and ``max_steps``.

    Yields:
      The rmse value returned by ``tf_evaluation`` for the restored checkpoint.
    """
    eval_data = args.eval_data_paths

    with tf.Graph().as_default() as g:
        metadata = features.FeatureMetadata.get_metadata(args.metadata_path)

        _, examples = taxifare.read_examples(eval_data, HYPERPARAMS['batch_size'], shuffle=False, num_epochs=1)

        # Generate placeholders for the examples.
        placeholder, inputs, targets, _ = (taxifare.create_inputs(
            metadata, examples, HYPERPARAMS))

        # Build a Graph that computes predictions from the inference model.
        output = taxifare.inference(inputs, metadata, HYPERPARAMS)

        # Add to the Graph the Ops for loss calculation.
        # NOTE(review): `loss` is not read below; presumably it is built for any
        # summaries taxifare.loss registers (picked up by merge_all_summaries
        # below) -- confirm before removing.
        loss = taxifare.loss(output, targets)

        # Add the Op to compute rmse. `eval_op` updates the streaming metric for
        # one batch; `rmse_op` reads the accumulated value.
        rmse_op, eval_op = metric_ops.streaming_root_mean_squared_error(
            output, targets)

        # The global step is useful for summaries.
        # NOTE(review): created under the 'train' name scope, presumably so its
        # name matches the global step of the training graph and the Saver can
        # restore it from the training checkpoints -- confirm against
        # taxifare.training.
        with tf.name_scope('train'):
            global_step = tf.Variable(0, name='global_step', trainable=False)

        tf.scalar_summary('rmse', rmse_op)
        tf.scalar_summary('training/hptuning/metric', rmse_op)
        summary = tf.merge_all_summaries()  # make sure all scalar summaries are produced

        saver = tf.train.Saver()

    # Setting num_eval_batches isn't strictly necessary, as the file reader does
    # at most one epoch. Note: float // int yields a float here.
    num_eval_batches = float(EVAL_SET_SIZE) // HYPERPARAMS['batch_size']
    summary_writer = tf.train.SummaryWriter(
        os.path.join(args.output_path, 'eval'))
    sv = tf.train.Supervisor(graph=g,
                             logdir=os.path.join(args.output_path, 'eval'),
                             summary_op=summary,
                             summary_writer=summary_writer,
                             global_step=None,
                             saver=saver)

    step = 0
    while step < args.max_steps:
        # NOTE(review): assumes a checkpoint already exists in logdir; restore
        # below is called with whatever latest_checkpoint returns.
        last_checkpoint = tf.train.latest_checkpoint(
            os.path.join(args.output_path, 'logdir'))
        # Standard services are not started; queue runners are started
        # explicitly on the next line.
        with sv.managed_session(master='',
                                start_standard_services=False) as session:
            sv.start_queue_runners(session)
            sv.saver.restore(session, last_checkpoint)
            rmse = tf_evaluation(session,
                                 max_num_evals=num_eval_batches,
                                 eval_op=eval_op,
                                 final_op=rmse_op,
                                 summary_op=summary,
                                 summary_writer=summary_writer,
                                 global_step=global_step)

            # Global step of the checkpoint just evaluated; drives loop exit.
            step = tf.train.global_step(session, global_step)
            yield rmse
def run_training(args, target, is_chief, device_fn):
    """Train the taxifare model for up to ``args.max_steps`` global steps.

    Builds the training graph under ``device_fn`` and drives it with a
    ``tf.train.Supervisor`` that checkpoints to ``<args.output_path>/logdir``
    (``save_model_secs=60``). On the chief only:

    * the model is checkpointed and evaluated through the ``evaluate``
      generator every ``EVAL_INTERVAL_SECS`` seconds;
    * summaries are written to ``<args.output_path>/summaries`` every 100
      steps;
    * after the loop, a final checkpoint is saved and evaluated, and the model
      is exported with ``export_model``.

    Args:
      args: Parsed arguments; this function reads ``train_data_paths``,
        ``metadata_path``, ``output_path`` and ``max_steps`` (``evaluate`` and
        ``export_model`` may read more).
      target: Master passed to ``sv.managed_session``.
      is_chief: Whether this worker is the chief; only the chief evaluates,
        writes summaries and exports the model.
      device_fn: Device specification passed to ``tf.device`` for the ops
        built here.
    """
    training_data = args.train_data_paths

    if is_chief:
        # A generator over rmse values: each next() restores the latest
        # checkpoint and evaluates it. It stops yielding once it has evaluated
        # a checkpoint whose global step is >= args.max_steps, so it is always
        # advanced with next(accuracies, None) below. Letting StopIteration
        # escape would abort the run before export_model.
        accuracies = evaluate(args)

    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default() as graph:
        # Assigns ops to the local worker by default.
        with tf.device(device_fn):
            metadata = features.FeatureMetadata.get_metadata(
                args.metadata_path)

            _, train_examples = taxifare.read_examples(
                training_data, HYPERPARAMS['batch_size'], shuffle=False)

            # Generate the model inputs and targets from the examples
            # (the placeholder output is not needed for training).
            _, inputs, targets, _ = (taxifare.create_inputs(
                metadata, train_examples, HYPERPARAMS))

            # Build a Graph that computes predictions from the inference model.
            output = taxifare.inference(inputs, metadata, HYPERPARAMS)

            # Add to the Graph the Ops for loss calculation.
            loss = taxifare.loss(output, targets)

            # Add to the Graph the Ops that calculate and apply gradients.
            train_op, global_step = taxifare.training(
                loss, HYPERPARAMS['learning_rate'])

            # Build the summary operation based on the TF collection of Summaries.
            summary_op = tf.merge_all_summaries()

            # Add the variable initializer Op.
            init_op = tf.initialize_all_variables()

            # Create a saver for writing training checkpoints.
            saver = tf.train.Saver()

            # Instantiate a SummaryWriter to output summaries and the Graph.
            summary_writer = tf.train.SummaryWriter(
                os.path.join(args.output_path, 'summaries'), graph)

            # Create a "supervisor", which oversees the training process.
            sv = tf.train.Supervisor(is_chief=is_chief,
                                     logdir=os.path.join(
                                         args.output_path, 'logdir'),
                                     init_op=init_op,
                                     saver=saver,
                                     summary_op=None,
                                     global_step=global_step,
                                     save_model_secs=60)

            # The supervisor takes care of session initialization, restoring from
            # a checkpoint, and closing when done or an error occurs.
            logging.info('Starting the loop.')
            with sv.managed_session(target) as sess:
                last_save = time.time()

                # Loop until the supervisor shuts down or max_steps have completed.
                step = 0
                while not sv.should_stop() and step < args.max_steps:
                    start_time = time.time()

                    # Run one step of the model. The return values are the activations
                    # from the `train_op` (which is discarded), the global step and
                    # the `loss` Op. To inspect the values of your Ops or variables,
                    # you may include them in the list passed to sess.run() and the
                    # value tensors will be returned in the tuple from the call.
                    _, step, loss_value = sess.run(
                        [train_op, global_step, loss])

                    duration = time.time() - start_time

                    # Periodic checkpoint + evaluation on the chief. Skipped once
                    # max_steps is reached: the loop is about to exit and the
                    # final block saves and evaluates anyway. Evaluating here as
                    # well would exhaust the evaluator before that final request.
                    if (is_chief and step < args.max_steps
                            and time.time() - last_save > EVAL_INTERVAL_SECS):
                        last_save = time.time()
                        saver.save(sess, sv.save_path, global_step)
                        rmse = next(accuracies, None)
                        if rmse is None:
                            # The saved global step can run ahead of `step` when
                            # other workers are training concurrently.
                            logging.info(
                                'Eval, step %d: evaluator already stopped at '
                                'max_steps', step)
                        else:
                            logging.info('Eval, step %d: rmse = %0.3f', step,
                                         rmse)

                    # Write the summaries and log an overview fairly often.
                    if step % 100 == 0 and is_chief:
                        # Log status.
                        logging.info('Step %d: loss = %.2f (%.3f sec)', step,
                                     loss_value, duration)

                        # Update the events file.
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                if is_chief:
                    # Force a save at the end of our loop.
                    sv.saver.save(sess, sv.save_path, global_step=global_step,
                                  write_meta_graph=False)

                    final_rmse = next(accuracies, None)
                    if final_rmse is None:
                        logging.info(
                            'No final rmse after %d steps: evaluator already '
                            'stopped at max_steps', step)
                    else:
                        logging.info('Final rmse after %d steps = %0.3f', step,
                                     final_rmse)

                    # Save the model for inference.
                    export_model(args, sess, sv.saver)

                # Ask for all the services to stop.
                sv.stop()
                logging.info('Done training.')