Example #1
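These excerpts come from a single trainer script, so each function relies on module-level imports and constants that are not shown. A minimal sketch of that preamble, using the TF 0.x-era APIs the code calls; the Cloud ML SDK import path and all constant values are assumptions:

import json
import logging
import os
import time

import tensorflow as tf
from tensorflow.contrib.metrics.python.ops import metric_ops  # streaming metrics

import google.cloud.ml.features as features  # assumed path in the old Cloud ML SDK
import taxifare  # the sample's model-building module (not shown here)

HYPERPARAMS = {'batch_size': 512, 'learning_rate': 0.01}  # illustrative values
EXPORT_SUBDIRECTORY = 'model'  # illustrative
EVAL_SET_SIZE = 2768  # illustrative
EVAL_INTERVAL_SECS = 30  # illustrative
# tf_evaluation (used in Example #2) is an evaluation-loop helper defined
# elsewhere in the trainer; it is not sketched here.
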
def export_model(args, sess, training_saver):
    with tf.Graph().as_default() as inference_graph:
        metadata = features.FeatureMetadata.get_metadata(args.metadata_path)
        placeholder, inputs, _, keys = taxifare.create_inputs(
            metadata, None, HYPERPARAMS)
        output = taxifare.inference(inputs, metadata, HYPERPARAMS)

        inference_saver = tf.train.Saver()

        # Mark the inputs and the outputs
        tf.add_to_collection('inputs',
                             json.dumps({'examples': placeholder.name}))
        tf.add_to_collection(
            'outputs', json.dumps({
                'key': keys.name,
                'score': output.name
            }))

        model_dir = os.path.join(args.output_path, EXPORT_SUBDIRECTORY)

        # We need the variable values from the training session, but the
        # serialized graph (MetaGraphDef) must be the serving graph built above.

        # Serialize the graph (MetaGraphDef)
        inference_saver.export_meta_graph(
            filename=os.path.join(model_dir, 'export.meta'))

        # Save the variables. Don't write the MetaGraphDef, because that is
        # actually the training graph.
        training_saver.save(sess,
                            os.path.join(model_dir, 'export'),
                            write_meta_graph=False)
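For reference, the exported artifacts can be reloaded for serving with the matching TF 0.x calls. A minimal sketch, assuming the same model_dir layout as above:

model_dir = os.path.join(args.output_path, EXPORT_SUBDIRECTORY)
with tf.Graph().as_default(), tf.Session() as serving_sess:
    # import_meta_graph rebuilds the serving graph and returns a Saver.
    loader = tf.train.import_meta_graph(os.path.join(model_dir, 'export.meta'))
    loader.restore(serving_sess, os.path.join(model_dir, 'export'))
    # The collections written above name the feed and fetch tensors.
    input_names = json.loads(tf.get_collection('inputs')[0])
    output_names = json.loads(tf.get_collection('outputs')[0])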
Example #2
def evaluate(args):
    """Run one round of evaluation, yielding rmse."""

    eval_data = args.eval_data_paths

    with tf.Graph().as_default() as g:
        metadata = features.FeatureMetadata.get_metadata(args.metadata_path)

        _, examples = taxifare.read_examples(eval_data,
                                             HYPERPARAMS['batch_size'],
                                             shuffle=False,
                                             num_epochs=1)

        # Generate placeholders for the examples.
        placeholder, inputs, targets, _ = (taxifare.create_inputs(
            metadata, examples, HYPERPARAMS))

        # Build a Graph that computes predictions from the inference model.
        output = taxifare.inference(inputs, metadata, HYPERPARAMS)

        # Add to the Graph the Ops for loss calculation.
        loss = taxifare.loss(output, targets)

        # Add the Op to compute rmse.
        rmse_op, eval_op = metric_ops.streaming_root_mean_squared_error(
            output, targets)

        # The global step is useful for summaries.
        with tf.name_scope('train'):
            global_step = tf.Variable(0, name='global_step', trainable=False)

        tf.scalar_summary('rmse', rmse_op)
        tf.scalar_summary('training/hptuning/metric', rmse_op)
        summary = tf.merge_all_summaries()  # ensure all scalar summaries are produced

        saver = tf.train.Saver()

    # Setting num_eval_batches isn't strictly necessary, as the file reader does
    # at most one epoch.
    num_eval_batches = EVAL_SET_SIZE // HYPERPARAMS['batch_size']
    summary_writer = tf.train.SummaryWriter(
        os.path.join(args.output_path, 'eval'))
    sv = tf.train.Supervisor(graph=g,
                             logdir=os.path.join(args.output_path, 'eval'),
                             summary_op=summary,
                             summary_writer=summary_writer,
                             global_step=None,
                             saver=saver)

    step = 0
    while step < args.max_steps:
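        # Training (Example #3) saves checkpoints under '<output_path>/logdir';
        # the restore() below assumes at least one checkpoint exists there.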
        last_checkpoint = tf.train.latest_checkpoint(
            os.path.join(args.output_path, 'logdir'))
        with sv.managed_session(master='',
                                start_standard_services=False) as session:
            sv.start_queue_runners(session)
            sv.saver.restore(session, last_checkpoint)
            rmse = tf_evaluation(session,
                                 max_num_evals=num_eval_batches,
                                 eval_op=eval_op,
                                 final_op=rmse_op,
                                 summary_op=summary,
                                 summary_writer=summary_writer,
                                 global_step=global_step)

            step = tf.train.global_step(session, global_step)
            yield rmse
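Note that evaluate() is a generator: each next() call restores the newest training checkpoint, runs one pass over the evaluation set, and yields the RMSE. This is exactly how run_training() in Example #3 consumes it:

accuracies = evaluate(args)
rmse = next(accuracies)  # forces one full evaluation round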
Example #3
def run_training(args, target, is_chief, device_fn):
    """Train Census for a number of steps."""
    # Get the sets of examples and targets for training, validation, and
    # test on Census.
    training_data = args.train_data_paths

    if is_chief:
        # A generator over accuracies. Each call to next(accuracies) forces an
        # evaluation of the model.
        accuracies = evaluate(args)

    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default() as graph:
        # Assigns ops to the local worker by default.
        with tf.device(device_fn):

            metadata = features.FeatureMetadata.get_metadata(
                args.metadata_path)

            _, train_examples = taxifare.read_examples(
                training_data, HYPERPARAMS['batch_size'], shuffle=False)

            # Generate placeholders for the examples.
            placeholder, inputs, targets, _ = (taxifare.create_inputs(
                metadata, train_examples, HYPERPARAMS))

            # Build a Graph that computes predictions from the inference model.
            output = taxifare.inference(inputs, metadata, HYPERPARAMS)

            # Add to the Graph the Ops for loss calculation.
            loss = taxifare.loss(output, targets)

            # Add to the Graph the Ops that calculate and apply gradients.
            train_op, global_step = taxifare.training(
                loss, HYPERPARAMS['learning_rate'])

            # Build the summary operation based on the TF collection of Summaries.
            summary_op = tf.merge_all_summaries()

            # Add the variable initializer Op.
            init_op = tf.initialize_all_variables()

            # Create a saver for writing training checkpoints.
            saver = tf.train.Saver()

            # Instantiate a SummaryWriter to output summaries and the Graph.
            summary_writer = tf.train.SummaryWriter(
                os.path.join(args.output_path, 'summaries'), graph)

            # Create a "supervisor", which oversees the training process.
            sv = tf.train.Supervisor(is_chief=is_chief,
                                     logdir=os.path.join(
                                         args.output_path, 'logdir'),
                                     init_op=init_op,
                                     saver=saver,
                                     summary_op=None,
                                     global_step=global_step,
                                     save_model_secs=60)

            # The supervisor takes care of session initialization, restoring from
            # a checkpoint, and closing when done or an error occurs.
            logging.info('Starting the loop.')
            with sv.managed_session(target) as sess:
                start_time = time.time()
                last_save = start_time

                # Loop until the supervisor shuts down or max_steps have completed.
                step = 0
                while not sv.should_stop() and step < args.max_steps:
                    start_time = time.time()

                    # Run one step of the model.  The return values are the activations
                    # from the `train_op` (which is discarded) and the `loss` Op.  To
                    # inspect the values of your Ops or variables, you may include them
                    # in the list passed to sess.run() and the value tensors will be
                    # returned in the tuple from the call.
                    _, step, loss_value = sess.run(
                        [train_op, global_step, loss])

                    duration = time.time() - start_time
                    if is_chief and time.time() - last_save > EVAL_INTERVAL_SECS:
                        last_save = time.time()
                        saver.save(sess, sv.save_path, global_step)
                        rmse = next(accuracies)
                        logging.info('Eval, step %d: rmse = %0.3f', step, rmse)

                    # Write the summaries and log an overview fairly often.
                    if step % 100 == 0 and is_chief:
                        # Log status.
                        logging.info('Step %d: loss = %.2f (%.3f sec)', step,
                                     loss_value, duration)

                        # Update the events file.
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                if is_chief:
                    # Force a save at the end of our loop.
                    sv.saver.save(sess,
                                  sv.save_path,
                                  global_step=global_step,
                                  write_meta_graph=False)

                    logging.info('Final rmse after %d steps = %0.3f', step,
                                 next(accuracies))

                    # Save the model for inference
                    export_model(args, sess, sv.saver)

            # Ask for all the services to stop.
            sv.stop()
            logging.info('Done training.')
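run_training() follows the between-graph replication pattern: each process calls it with its own session target, chief flag, and device function. A hedged single-process sketch (args comes from the trainer's flag parsing, which is not shown):

server = tf.train.Server.create_local_server()
run_training(args, server.target, is_chief=True, device_fn=None)

In a distributed job, device_fn would typically be tf.train.replica_device_setter(...) so variables are placed on parameter servers while ops stay on the worker.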