def feed_dict(mgr, batch_size):
    tmp = TFNode.next_batch(mgr, batch_size)
    # extract TFRecords, since tmp array is [(TFRecord, None)]
    tfrecords = []
    for elem in tmp:
        tfrecords.append(str(elem[0]))
    return tfrecords
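The strings returned above are serialized `tf.train.Example` records. As a minimal sketch (not part of the original snippet), they can be decoded into plain Python lists as shown below; the feature keys 'image' and 'label' are assumptions for illustration, not taken from the example:

import tensorflow as tf

def parse_tfrecords(tfrecords):
    # Decode serialized tf.train.Example protos into plain Python lists.
    images, labels = [], []
    for rec in tfrecords:
        ex = tf.train.Example.FromString(rec)
        # The feature keys below are illustrative; use the keys the records
        # were actually written with.
        images.append(list(ex.features.feature['image'].int64_list.value))
        labels.append(list(ex.features.feature['label'].int64_list.value))
    return images, labels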
Example 2
    def feed_dict():
        # Get a batch of examples from spark data feeder job
        batch = TFNode.next_batch(ctx.mgr, 100)

        # Convert from [(images, labels)] to two numpy arrays of the proper type
        images = []
        labels = []
        for item in batch:
            images.append(item[0])
            labels.append(item[1])
        xs = numpy.array(images)
        xs = xs.astype(numpy.float32)
        xs = xs / 255.0
        ys = numpy.array(labels)
        ys = ys.astype(numpy.uint8)
        return (xs, ys)
Example 3
def main_fun(argv, ctx):
    import sys
    import tensorflow as tf
    from inception import inception_eval
    from inception.imagenet_data import ImagenetData

    print("argv:", argv)
    sys.argv = argv

    FLAGS = tf.app.flags.FLAGS
    FLAGS._parse_flags()
    print("FLAGS:", FLAGS.__dict__['__flags'])

    dataset = ImagenetData(subset=FLAGS.subset)
    assert dataset.data_files()
    if tf.gfile.Exists(FLAGS.eval_dir):
        tf.gfile.DeleteRecursively(FLAGS.eval_dir)
    tf.gfile.MakeDirs(FLAGS.eval_dir)

    cluster_spec, server = TFNode.start_cluster_server(ctx, 1, FLAGS.rdma)

    inception_eval.evaluate(dataset)
Example 4
def main_fun(argv, ctx):

    # extract node metadata from ctx
    worker_num = ctx.worker_num
    job_name = ctx.job_name
    task_index = ctx.task_index

    assert job_name in ['ps', 'worker'], 'job_name must be ps or worker'

    import sys
    from inception import inception_distributed_train
    from inception.imagenet_data import ImagenetData
    import tensorflow as tf

    # instantiate FLAGS on workers using argv from driver and add job_name and task_id
    print("argv:", argv)
    sys.argv = argv

    FLAGS = tf.app.flags.FLAGS
    FLAGS.job_name = job_name
    FLAGS.task_id = task_index
    print("FLAGS:", FLAGS.__dict__['__flags'])

    # Get TF cluster and server instances
    cluster_spec, server = TFNode.start_cluster_server(ctx, FLAGS.num_gpus,
                                                       FLAGS.rdma)

    if FLAGS.job_name == 'ps':
        # `ps` jobs wait for incoming connections from the workers.
        server.join()
    else:
        # `worker` jobs will actually do the work.
        dataset = ImagenetData(subset=FLAGS.subset)
        assert dataset.data_files()
        # Only the chief checks for or creates train_dir.
        if FLAGS.task_id == 0:
            if not tf.gfile.Exists(FLAGS.train_dir):
                tf.gfile.MakeDirs(FLAGS.train_dir)
        inception_distributed_train.train(server.target, dataset, cluster_spec,
                                          ctx)
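These `main_fun(argv, ctx)` functions are not called directly; TensorFlowOnSpark ships them to Spark executors. Below is a hedged driver-side sketch, assuming the `TFCluster` API; the application name, executor counts, and forwarding of `sys.argv` are illustrative and not taken from the example above.

import sys
from pyspark import SparkConf, SparkContext
from tensorflowonspark import TFCluster

sc = SparkContext(conf=SparkConf().setAppName("inception_distributed_train"))

num_executors = 4  # illustrative cluster size
num_ps = 1         # illustrative number of parameter servers

# main_fun is the worker function defined above; sys.argv is forwarded to it as `argv`.
cluster = TFCluster.run(sc, main_fun, sys.argv, num_executors, num_ps,
                        tensorboard=False, input_mode=TFCluster.InputMode.TENSORFLOW)
cluster.shutdown()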
Example 5
def main_fun(argv, ctx):
    import math
    import sys
    import tensorflow as tf

    from datasets import dataset_factory
    from nets import nets_factory
    from preprocessing import preprocessing_factory

    sys.argv = argv

    slim = tf.contrib.slim

    tf.app.flags.DEFINE_integer('batch_size', 100,
                                'The number of samples in each batch.')

    tf.app.flags.DEFINE_integer(
        'max_num_batches', None,
        'Max number of batches to evaluate; by default, use all.')

    tf.app.flags.DEFINE_string('master', '',
                               'The address of the TensorFlow master to use.')

    tf.app.flags.DEFINE_string(
        'checkpoint_path', '/tmp/tfmodel/',
        'The directory where the model was written to or an absolute path to a '
        'checkpoint file.')

    tf.app.flags.DEFINE_string('eval_dir', '/tmp/tfmodel/',
                               'Directory where the results are saved to.')

    tf.app.flags.DEFINE_integer(
        'num_preprocessing_threads', 4,
        'The number of threads used to create the batches.')

    tf.app.flags.DEFINE_string('dataset_name', 'imagenet',
                               'The name of the dataset to load.')

    tf.app.flags.DEFINE_string('dataset_split_name', 'test',
                               'The name of the train/test split.')

    tf.app.flags.DEFINE_string(
        'dataset_dir', None,
        'The directory where the dataset files are stored.')

    tf.app.flags.DEFINE_integer(
        'labels_offset', 0,
        'An offset for the labels in the dataset. This flag is primarily used to '
        'evaluate the VGG and ResNet architectures which do not use a background '
        'class for the ImageNet dataset.')

    tf.app.flags.DEFINE_string('model_name', 'inception_v3',
                               'The name of the architecture to evaluate.')

    tf.app.flags.DEFINE_string(
        'preprocessing_name', None,
        'The name of the preprocessing to use. If left '
        'as `None`, then the model_name flag is used.')

    tf.app.flags.DEFINE_float(
        'moving_average_decay', None,
        'The decay to use for the moving average. '
        'If left as None, then moving averages are not used.')

    tf.app.flags.DEFINE_integer('eval_image_size', None, 'Eval image size')

    FLAGS = tf.app.flags.FLAGS

    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    cluster_spec, server = TFNode.start_cluster_server(ctx)

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ####################
        # Select the model #
        ####################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=False)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            common_queue_capacity=2 * FLAGS.batch_size,
            common_queue_min=FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])
        label -= FLAGS.labels_offset

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=False)

        eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

        image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)

        ####################
        # Define the model #
        ####################
        logits, _ = network_fn(images)

        if FLAGS.moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        predictions = tf.argmax(logits, 1)
        labels = tf.squeeze(labels)

        # Define the metrics:
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy':
            slim.metrics.streaming_accuracy(predictions, labels),
            'Recall@5':
            slim.metrics.streaming_recall_at_k(logits, labels, 5),
        })

        # Print the summaries to screen.
        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

        # TODO(sguada) use num_epochs=1
        if FLAGS.max_num_batches:
            num_batches = FLAGS.max_num_batches
        else:
            # This ensures that we make a single pass over all of the data.
            num_batches = math.ceil(dataset.num_samples /
                                    float(FLAGS.batch_size))

        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        else:
            checkpoint_path = FLAGS.checkpoint_path

        tf.logging.info('Evaluating %s' % checkpoint_path)

        slim.evaluation.evaluate_once(
            master=FLAGS.master,
            checkpoint_path=checkpoint_path,
            logdir=FLAGS.eval_dir,
            num_evals=num_batches,
            eval_op=list(names_to_updates.values()),
            variables_to_restore=variables_to_restore)
Example 6
def map_fun(args, ctx):
    from com.yahoo.ml.tf import TFNode
    from datetime import datetime
    import math
    import numpy
    import tensorflow as tf
    import time

    worker_num = ctx.worker_num
    job_name = ctx.job_name
    task_index = ctx.task_index
    cluster_spec = ctx.cluster_spec

    IMAGE_PIXELS = 28

    # Delay PS nodes a bit, since workers seem to reserve GPUs more quickly/reliably (w/o conflict)
    if job_name == "ps":
        time.sleep((worker_num + 1) * 5)

    # Parameters
    hidden_units = 128
    batch_size = 100

    # Get TF cluster and server instances
    cluster, server = TFNode.start_cluster_server(ctx, 1, args.rdma)

    def feed_dict():
        # Get a batch of examples from spark data feeder job
        batch = TFNode.next_batch(ctx.mgr, 100)

        # Convert from [(images, labels)] to two numpy arrays of the proper type
        images = []
        labels = []
        for item in batch:
            images.append(item[0])
            labels.append(item[1])
        xs = numpy.array(images)
        xs = xs.astype(numpy.float32)
        xs = xs / 255.0
        ys = numpy.array(labels)
        ys = ys.astype(numpy.uint8)
        return (xs, ys)

    if job_name == "ps":
        server.join()
    elif job_name == "worker":

        # Assigns ops to the local worker by default.
        with tf.device(
                tf.train.replica_device_setter(
                    worker_device="/job:worker/task:%d" % task_index,
                    cluster=cluster)):

            # Variables of the hidden layer
            hid_w = tf.Variable(tf.truncated_normal(
                [IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
                stddev=1.0 / IMAGE_PIXELS),
                                name="hid_w")
            hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")

            # Variables of the softmax layer
            sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
                                                   stddev=1.0 /
                                                   math.sqrt(hidden_units)),
                               name="sm_w")
            sm_b = tf.Variable(tf.zeros([10]), name="sm_b")

            # Placeholders or QueueRunner/Readers for input data
            x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS],
                               name="x")
            y_ = tf.placeholder(tf.float32, [None, 10], name="y_")

            hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
            hid = tf.nn.relu(hid_lin)

            y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))

            global_step = tf.Variable(0)

            loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
            train_op = tf.train.AdagradOptimizer(0.01).minimize(
                loss, global_step=global_step)

            # Test trained model
            label = tf.argmax(y_, 1, name="label")
            prediction = tf.argmax(y, 1, name="prediction")
            correct_prediction = tf.equal(prediction, label)
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32),
                                      name="accuracy")

            saver = tf.train.Saver()
            summary_op = tf.summary.merge_all()
            init_op = tf.global_variables_initializer()

        # Create a "supervisor", which oversees the training process and stores model state into HDFS
        logdir = TFNode.hdfs_path(ctx, args.model)
        print("tensorflow model path: {0}".format(logdir))
        summary_writer = tf.summary.FileWriter("tensorboard_%d" % (worker_num),
                                               graph=tf.get_default_graph())

        if args.mode == "train":
            sv = tf.train.Supervisor(is_chief=(task_index == 0),
                                     logdir=logdir,
                                     init_op=init_op,
                                     summary_op=summary_op,
                                     saver=saver,
                                     global_step=global_step,
                                     summary_writer=summary_writer,
                                     stop_grace_secs=300,
                                     save_model_secs=10)
        else:
            sv = tf.train.Supervisor(is_chief=(task_index == 0),
                                     logdir=logdir,
                                     saver=saver,
                                     global_step=global_step,
                                     stop_grace_secs=300,
                                     save_model_secs=0)

        # The supervisor takes care of session initialization, restoring from
        # a checkpoint, and closing when done or an error occurs.
        with sv.managed_session(server.target) as sess:
            print("{0} session ready".format(datetime.now().isoformat()))

            # Loop until the supervisor shuts down or 1000000 steps have completed.
            step = 0
            count = 0
            while not sv.should_stop() and step < args.steps:
                # Run a training step asynchronously.
                # See `tf.train.SyncReplicasOptimizer` for additional details on how to
                # perform *synchronous* training.

                # using feed_dict
                batch_xs, batch_ys = feed_dict()
                feed = {x: batch_xs, y_: batch_ys}

                if len(batch_xs) != batch_size:
                    print("done feeding")
                    break
                else:
                    if args.mode == "train":
                        _, step = sess.run([train_op, global_step],
                                           feed_dict=feed)
                        # print accuracy and save model checkpoint to HDFS every 100 steps
                        if (step % 100 == 0):
                            print("{0} step: {1} accuracy: {2}".format(
                                datetime.now().isoformat(), step,
                                sess.run(accuracy, {
                                    x: batch_xs,
                                    y_: batch_ys
                                })))
                    else:  # args.mode == "inference"
                        labels, preds, acc = sess.run(
                            [label, prediction, accuracy], feed_dict=feed)

                        results = [
                            "{0} Label: {1}, Prediction: {2}".format(
                                datetime.now().isoformat(), l, p)
                            for l, p in zip(labels, preds)
                        ]
                        TFNode.batch_results(ctx.mgr, results)
                        print("acc: {0}".format(acc))

            if sv.should_stop() or step >= args.steps:
                TFNode.terminate(ctx.mgr)

        # Ask for all the services to stop.
        print("{0} stopping supervisor".format(datetime.now().isoformat()))
        sv.stop()
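Because this `map_fun` pulls its data through `TFNode.next_batch` and pushes results with `TFNode.batch_results`, the driver has to feed it an RDD. A hedged sketch of that driver side follows, assuming the TensorFlowOnSpark `TFCluster` API in Spark input mode; the placeholder RDD, argument object, and output path are illustrative only.

import argparse
from pyspark import SparkConf, SparkContext
from tensorflowonspark import TFCluster

sc = SparkContext(conf=SparkConf().setAppName("mnist_spark"))

# Hypothetical argument object covering the fields map_fun reads
# (args.mode, args.steps, args.epochs, args.model, args.rdma).
args = argparse.Namespace(mode="train", steps=1000, epochs=1,
                          model="mnist_model", output="predictions", rdma=False)

# A placeholder RDD of (image_vector, label_vector) pairs; a real job would
# build this from the MNIST data.
dataRDD = sc.parallelize([([0.0] * 784, [0.0] * 9 + [1.0])] * 1000)

cluster = TFCluster.run(sc, map_fun, args, 4, 1,
                        tensorboard=False, input_mode=TFCluster.InputMode.SPARK)
if args.mode == "train":
    cluster.train(dataRDD, args.epochs)    # feeds batches to TFNode.next_batch
else:
    labelRDD = cluster.inference(dataRDD)  # collects TFNode.batch_results output
    labelRDD.saveAsTextFile(args.output)
cluster.shutdown()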
Example 7
def map_fun(args, ctx):
  from com.yahoo.ml.tf import TFNode
  from datetime import datetime
  import getpass
  import math
  import numpy
  import os
  import signal
  import tensorflow as tf
  import time

  IMAGE_PIXELS = 28
  worker_num = ctx.worker_num
  job_name = ctx.job_name
  task_index = ctx.task_index
  cluster_spec = ctx.cluster_spec
  num_workers = len(cluster_spec['worker'])

  # Delay PS nodes a bit, since workers seem to reserve GPUs more quickly/reliably (w/o conflict)
  if job_name == "ps":
    time.sleep((worker_num + 1) * 5)

  # Parameters
  hidden_units = 128
  batch_size   = 100

  # Get TF cluster and server instances
  cluster, server = TFNode.start_cluster_server(ctx, 1, args.rdma)

  def read_csv_examples(image_dir, label_dir, batch_size=100, num_epochs=None, task_index=None, num_workers=None):
    print_log(worker_num, "num_epochs: {0}".format(num_epochs))
    # Setup queue of csv image filenames
    tf_record_pattern = os.path.join(image_dir, 'part-*')
    images = tf.gfile.Glob(tf_record_pattern)
    print_log(worker_num, "images: {0}".format(images))
    image_queue = tf.train.string_input_producer(images, shuffle=False, capacity=1000, num_epochs=num_epochs, name="image_queue")

    # Setup queue of csv label filenames
    tf_record_pattern = os.path.join(label_dir, 'part-*')
    labels = tf.gfile.Glob(tf_record_pattern)
    print_log(worker_num, "labels: {0}".format(labels))
    label_queue = tf.train.string_input_producer(labels, shuffle=False, capacity=1000, num_epochs=num_epochs, name="label_queue")

    # Setup reader for image queue
    img_reader = tf.TextLineReader(name="img_reader")
    _, img_csv = img_reader.read(image_queue)
    image_defaults = [ [1.0] for col in range(784) ]
    img = tf.stack(tf.decode_csv(img_csv, image_defaults))
    # Normalize values to [0,1]
    norm = tf.constant(255, dtype=tf.float32, shape=(784,))
    image = tf.div(img, norm)
    print_log(worker_num, "image: {0}".format(image))

    # Setup reader for label queue
    label_reader = tf.TextLineReader(name="label_reader")
    _, label_csv = label_reader.read(label_queue)
    label_defaults = [ [1.0] for col in range(10) ]
    label = tf.stack(tf.decode_csv(label_csv, label_defaults))
    print_log(worker_num, "label: {0}".format(label))

    # Return a batch of examples
    return tf.train.batch([image,label], batch_size, num_threads=args.readers, name="batch_csv")

  def read_tfr_examples(path, batch_size=100, num_epochs=None, task_index=None, num_workers=None):
    print_log(worker_num, "num_epochs: {0}".format(num_epochs))

    # Setup queue of TFRecord filenames
    tf_record_pattern = os.path.join(path, 'part-*')
    files = tf.gfile.Glob(tf_record_pattern)
    queue_name = "file_queue"

    # split input files across workers, if specified
    if task_index is not None and num_workers is not None:
      num_files = len(files)
      files = files[task_index:num_files:num_workers]
      queue_name = "file_queue_{0}".format(task_index)

    print_log(worker_num, "files: {0}".format(files))
    file_queue = tf.train.string_input_producer(files, shuffle=False, capacity=1000, num_epochs=num_epochs, name=queue_name)

    # Setup reader for examples
    reader = tf.TFRecordReader(name="reader")
    _, serialized = reader.read(file_queue)
    feature_def = {'label': tf.FixedLenFeature([10], tf.int64), 'image': tf.FixedLenFeature([784], tf.int64) }
    features = tf.parse_single_example(serialized, feature_def)
    norm = tf.constant(255, dtype=tf.float32, shape=(784,))
    image = tf.div(tf.to_float(features['image']), norm)
    print_log(worker_num, "image: {0}".format(image))
    label = tf.to_float(features['label'])
    print_log(worker_num, "label: {0}".format(label))

    # Return a batch of examples
    return tf.train.batch([image,label], batch_size, num_threads=args.readers, name="batch")

  if job_name == "ps":
    server.join()
  elif job_name == "worker":
    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % task_index,
        cluster=cluster)):

      # Variables of the hidden layer
      hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
                              stddev=1.0 / IMAGE_PIXELS), name="hid_w")
      hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")

      # Variables of the softmax layer
      sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
                              stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
      sm_b = tf.Variable(tf.zeros([10]), name="sm_b")

      # Placeholders or QueueRunner/Readers for input data
      num_epochs = 1 if args.mode == "inference" else None if args.epochs == 0 else args.epochs
      index = task_index if args.mode == "inference" else None
      workers = num_workers if args.mode == "inference" else None

      if args.format == "csv":
        images = TFNode.hdfs_path(ctx, args.images)
        labels = TFNode.hdfs_path(ctx, args.labels)
        x, y_ = read_csv_examples(images, labels, 100, num_epochs, index, workers)
      elif args.format == "tfr":
        images = TFNode.hdfs_path(ctx, args.images)
        x, y_ = read_tfr_examples(images, 100, num_epochs, index, workers)
      else:
        raise("{0} format not supported for tf input mode".format(args.format))

      hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
      hid = tf.nn.relu(hid_lin)

      y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))

      global_step = tf.Variable(0)

      loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
      train_op = tf.train.AdagradOptimizer(0.01).minimize(
          loss, global_step=global_step)

      # Test trained model
      label = tf.argmax(y_, 1, name="label")
      prediction = tf.argmax(y, 1, name="prediction")
      correct_prediction = tf.equal(prediction, label)
      accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")

      saver = tf.train.Saver()
      summary_op = tf.summary.merge_all()
      init_op = tf.global_variables_initializer()

    # Create a "supervisor", which oversees the training process and stores model state into HDFS
    logdir = TFNode.hdfs_path(ctx, args.model)
    print("tensorflow model path: {0}".format(logdir))
    summary_writer = tf.summary.FileWriter("tensorboard_%d" %(worker_num), graph=tf.get_default_graph())

    if args.mode == "train":
      sv = tf.train.Supervisor(is_chief=(task_index == 0),
                               logdir=logdir,
                               init_op=init_op,
                               summary_op=summary_op,
                               saver=saver,
                               global_step=global_step,
                               summary_writer=summary_writer,
                               stop_grace_secs=300,
                               save_model_secs=10)
    else:
      sv = tf.train.Supervisor(is_chief=(task_index == 0),
                               logdir=logdir,
                               saver=saver,
                               global_step=global_step,
                               stop_grace_secs=300,
                               save_model_secs=0)
      output_dir = TFNode.hdfs_path(ctx, args.output)
      output_file = tf.gfile.Open("{0}/part-{1:05d}".format(output_dir, worker_num), mode='w')

    # The supervisor takes care of session initialization, restoring from
    # a checkpoint, and closing when done or an error occurs.
    with sv.managed_session(server.target) as sess:
      print("{0} session ready".format(datetime.now().isoformat()))

      # Loop until the supervisor shuts down or 1000000 steps have completed.
      step = 0
      count = 0
      while not sv.should_stop() and step < args.steps:
        # Run a training step asynchronously.
        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
        # perform *synchronous* training.

        # using QueueRunners/Readers
        if args.mode == "train":
          if (step % 100 == 0):
            print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy)))
          _, summary, step = sess.run([train_op, summary_op, global_step])
          summary_writer.add_summary(summary, step)
        else: # args.mode == "inference"
          labels, pred, acc = sess.run([label, prediction, accuracy])
          #print("label: {0}, pred: {1}".format(labels, pred))
          print("acc: {0}".format(acc))
          for i in range(len(labels)):
            count += 1
            output_file.write("{0} {1}\n".format(labels[i], pred[i]))
          print("count: {0}".format(count))

    if args.mode == "inference":
      output_file.close()

    # Ask for all the services to stop.
    print("{0} stopping supervisor".format(datetime.now().isoformat()))
    sv.stop()
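The CSV readers above expect `part-*` files with 784 comma-separated pixel values per line and a parallel label directory with 10-value one-hot lines. A hedged sketch of producing such files with plain Spark; the placeholder RDD and output paths are illustrative, not from the original example.

from pyspark import SparkConf, SparkContext

sc = SparkContext(conf=SparkConf().setAppName("mnist_csv_export"))

# Placeholder (image_vector, one_hot_label) pairs so the sketch runs end to end;
# a real job would build this RDD from the MNIST data.
pairs = sc.parallelize([([0] * 784, [0] * 9 + [1])] * 100)

def to_csv(values):
    return ','.join(str(v) for v in values)

pairs.map(lambda p: to_csv(p[0])).saveAsTextFile("mnist/csv/train/images")
pairs.map(lambda p: to_csv(p[1])).saveAsTextFile("mnist/csv/train/labels")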
Example 8
def main_fun(argv, ctx):
    import sys
    import tensorflow as tf
    from tensorflow.python.ops import control_flow_ops
    from datasets import dataset_factory
    from deployment import model_deploy
    from nets import nets_factory
    from preprocessing import preprocessing_factory

    sys.argv = argv

    slim = tf.contrib.slim

    tf.app.flags.DEFINE_integer('num_gpus', 1,
                                'The number of GPUs to use per node')

    tf.app.flags.DEFINE_boolean('rdma', False, 'Whether to use rdma.')

    tf.app.flags.DEFINE_string('master', '',
                               'The address of the TensorFlow master to use.')

    tf.app.flags.DEFINE_string(
        'train_dir', '/tmp/tfmodel/',
        'Directory where checkpoints and event logs are written to.')

    tf.app.flags.DEFINE_integer('num_clones', 1,
                                'Number of model clones to deploy.')

    tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
                                'Use CPUs to deploy clones.')

    tf.app.flags.DEFINE_integer('worker_replicas', 1,
                                'Number of worker replicas.')

    tf.app.flags.DEFINE_integer(
        'num_ps_tasks', 0,
        'The number of parameter servers. If the value is 0, then the parameters '
        'are handled locally by the worker.')

    tf.app.flags.DEFINE_integer(
        'num_readers', 4,
        'The number of parallel readers that read data from the dataset.')

    tf.app.flags.DEFINE_integer(
        'num_preprocessing_threads', 4,
        'The number of threads used to create the batches.')

    tf.app.flags.DEFINE_integer('log_every_n_steps', 10,
                                'The frequency with which logs are printed.')

    tf.app.flags.DEFINE_integer(
        'save_summaries_secs', 600,
        'The frequency with which summaries are saved, in seconds.')

    tf.app.flags.DEFINE_integer(
        'save_interval_secs', 600,
        'The frequency with which the model is saved, in seconds.')

    tf.app.flags.DEFINE_integer(
        'task', 0, 'Task id of the replica running the training.')

    ######################
    # Optimization Flags #
    ######################

    tf.app.flags.DEFINE_float('weight_decay', 0.00004,
                              'The weight decay on the model weights.')

    tf.app.flags.DEFINE_string(
        'optimizer', 'rmsprop',
        'The name of the optimizer, one of "adadelta", "adagrad", "adam", '
        '"ftrl", "momentum", "sgd" or "rmsprop".')

    tf.app.flags.DEFINE_float('adadelta_rho', 0.95,
                              'The decay rate for adadelta.')

    tf.app.flags.DEFINE_float('adagrad_initial_accumulator_value', 0.1,
                              'Starting value for the AdaGrad accumulators.')

    tf.app.flags.DEFINE_float(
        'adam_beta1', 0.9,
        'The exponential decay rate for the 1st moment estimates.')

    tf.app.flags.DEFINE_float(
        'adam_beta2', 0.999,
        'The exponential decay rate for the 2nd moment estimates.')

    tf.app.flags.DEFINE_float('opt_epsilon', 1.0,
                              'Epsilon term for the optimizer.')

    tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
                              'The learning rate power.')

    tf.app.flags.DEFINE_float('ftrl_initial_accumulator_value', 0.1,
                              'Starting value for the FTRL accumulators.')

    tf.app.flags.DEFINE_float('ftrl_l1', 0.0,
                              'The FTRL l1 regularization strength.')

    tf.app.flags.DEFINE_float('ftrl_l2', 0.0,
                              'The FTRL l2 regularization strength.')

    tf.app.flags.DEFINE_float(
        'momentum', 0.9,
        'The momentum for the MomentumOptimizer and RMSPropOptimizer.')

    tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')

    tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')

    #######################
    # Learning Rate Flags #
    #######################

    tf.app.flags.DEFINE_string(
        'learning_rate_decay_type', 'exponential',
        'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
        ' or "polynomial"')

    tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')

    tf.app.flags.DEFINE_float(
        'end_learning_rate', 0.0001,
        'The minimal end learning rate used by a polynomial decay learning rate.'
    )

    tf.app.flags.DEFINE_float('label_smoothing', 0.0,
                              'The amount of label smoothing.')

    tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.94,
                              'Learning rate decay factor.')

    tf.app.flags.DEFINE_float(
        'num_epochs_per_decay', 2.0,
        'Number of epochs after which learning rate decays.')

    tf.app.flags.DEFINE_bool(
        'sync_replicas', False,
        'Whether or not to synchronize the replicas during training.')

    tf.app.flags.DEFINE_integer(
        'replicas_to_aggregate', 1,
        'The number of gradients to collect before updating params.')

    tf.app.flags.DEFINE_float(
        'moving_average_decay', None,
        'The decay to use for the moving average. '
        'If left as None, then moving averages are not used.')

    #######################
    # Dataset Flags #
    #######################

    tf.app.flags.DEFINE_string('dataset_name', 'imagenet',
                               'The name of the dataset to load.')

    tf.app.flags.DEFINE_string('dataset_split_name', 'train',
                               'The name of the train/test split.')

    tf.app.flags.DEFINE_string(
        'dataset_dir', None,
        'The directory where the dataset files are stored.')

    tf.app.flags.DEFINE_integer(
        'labels_offset', 0,
        'An offset for the labels in the dataset. This flag is primarily used to '
        'evaluate the VGG and ResNet architectures which do not use a background '
        'class for the ImageNet dataset.')

    tf.app.flags.DEFINE_string('model_name', 'inception_v3',
                               'The name of the architecture to train.')

    tf.app.flags.DEFINE_string(
        'preprocessing_name', None,
        'The name of the preprocessing to use. If left '
        'as `None`, then the model_name flag is used.')

    tf.app.flags.DEFINE_integer('batch_size', 32,
                                'The number of samples in each batch.')

    tf.app.flags.DEFINE_integer('train_image_size', None, 'Train image size')

    tf.app.flags.DEFINE_integer('max_number_of_steps', None,
                                'The maximum number of training steps.')

    #####################
    # Fine-Tuning Flags #
    #####################

    tf.app.flags.DEFINE_string(
        'checkpoint_path', None,
        'The path to a checkpoint from which to fine-tune.')

    tf.app.flags.DEFINE_string(
        'checkpoint_exclude_scopes', None,
        'Comma-separated list of scopes of variables to exclude when restoring '
        'from a checkpoint.')

    tf.app.flags.DEFINE_string(
        'trainable_scopes', None,
        'Comma-separated list of scopes to filter the set of variables to train. '
        'By default, None would train all the variables.')

    tf.app.flags.DEFINE_boolean(
        'ignore_missing_vars', False,
        'When restoring a checkpoint, ignore missing variables.')

    FLAGS = tf.app.flags.FLAGS
    FLAGS.job_name = ctx.job_name
    FLAGS.task = ctx.task_index
    FLAGS.num_clones = FLAGS.num_gpus
    FLAGS.worker_replicas = len(ctx.cluster_spec['worker'])
    assert (FLAGS.num_ps_tasks == (len(ctx.cluster_spec['ps'])
                                   if 'ps' in ctx.cluster_spec else 0))

    def _configure_learning_rate(num_samples_per_epoch, global_step):
        """Configures the learning rate.

    Args:
      num_samples_per_epoch: The number of samples in each epoch of training.
      global_step: The global_step tensor.

    Returns:
      A `Tensor` representing the learning rate.

    Raises:
      ValueError: if the learning rate decay type is not recognized.
    """
        decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                          FLAGS.num_epochs_per_decay)
        if FLAGS.sync_replicas:
            decay_steps /= FLAGS.replicas_to_aggregate

        if FLAGS.learning_rate_decay_type == 'exponential':
            return tf.train.exponential_decay(
                FLAGS.learning_rate,
                global_step,
                decay_steps,
                FLAGS.learning_rate_decay_factor,
                staircase=True,
                name='exponential_decay_learning_rate')
        elif FLAGS.learning_rate_decay_type == 'fixed':
            return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
        elif FLAGS.learning_rate_decay_type == 'polynomial':
            return tf.train.polynomial_decay(
                FLAGS.learning_rate,
                global_step,
                decay_steps,
                FLAGS.end_learning_rate,
                power=1.0,
                cycle=False,
                name='polynomial_decay_learning_rate')
        else:
            raise ValueError(
                'learning_rate_decay_type [%s] was not recognized' %
                FLAGS.learning_rate_decay_type)

    def _configure_optimizer(learning_rate):
        """Configures the optimizer used for training.

    Args:
      learning_rate: A scalar or `Tensor` learning rate.

    Returns:
      An instance of an optimizer.

    Raises:
      ValueError: if FLAGS.optimizer is not recognized.
    """
        if FLAGS.optimizer == 'adadelta':
            optimizer = tf.train.AdadeltaOptimizer(learning_rate,
                                                   rho=FLAGS.adadelta_rho,
                                                   epsilon=FLAGS.opt_epsilon)
        elif FLAGS.optimizer == 'adagrad':
            optimizer = tf.train.AdagradOptimizer(
                learning_rate,
                initial_accumulator_value=FLAGS.
                adagrad_initial_accumulator_value)
        elif FLAGS.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate,
                                               beta1=FLAGS.adam_beta1,
                                               beta2=FLAGS.adam_beta2,
                                               epsilon=FLAGS.opt_epsilon)
        elif FLAGS.optimizer == 'ftrl':
            optimizer = tf.train.FtrlOptimizer(
                learning_rate,
                learning_rate_power=FLAGS.ftrl_learning_rate_power,
                initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
                l1_regularization_strength=FLAGS.ftrl_l1,
                l2_regularization_strength=FLAGS.ftrl_l2)
        elif FLAGS.optimizer == 'momentum':
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   momentum=FLAGS.momentum,
                                                   name='Momentum')
        elif FLAGS.optimizer == 'rmsprop':
            optimizer = tf.train.RMSPropOptimizer(
                learning_rate,
                decay=FLAGS.rmsprop_decay,
                momentum=FLAGS.rmsprop_momentum,
                epsilon=FLAGS.opt_epsilon)
        elif FLAGS.optimizer == 'sgd':
            optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        else:
            raise ValueError('Optimizer [%s] was not recognized' %
                             FLAGS.optimizer)
        return optimizer

    def _add_variables_summaries(learning_rate):
        summaries = []
        for variable in slim.get_model_variables():
            summaries.append(tf.summary.histogram(variable.op.name, variable))
        summaries.append(
            tf.summary.scalar('training/Learning Rate', learning_rate))
        return summaries

    def _get_init_fn():
        """Returns a function run by the chief worker to warm-start the training.

    Note that the init_fn is only run when initializing the model during the very
    first global step.

    Returns:
      An init function run by the supervisor.
    """
        if FLAGS.checkpoint_path is None:
            return None

        # Warn the user if a checkpoint exists in the train_dir. Then we'll be
        # ignoring the checkpoint anyway.
        if tf.train.latest_checkpoint(FLAGS.train_dir):
            tf.logging.info(
                'Ignoring --checkpoint_path because a checkpoint already exists in %s'
                % FLAGS.train_dir)
            return None

        exclusions = []
        if FLAGS.checkpoint_exclude_scopes:
            exclusions = [
                scope.strip()
                for scope in FLAGS.checkpoint_exclude_scopes.split(',')
            ]

        # TODO(sguada) variables.filter_variables()
        variables_to_restore = []
        for var in slim.get_model_variables():
            excluded = False
            for exclusion in exclusions:
                if var.op.name.startswith(exclusion):
                    excluded = True
                    break
            if not excluded:
                variables_to_restore.append(var)

        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        else:
            checkpoint_path = FLAGS.checkpoint_path

        tf.logging.info('Fine-tuning from %s' % checkpoint_path)

        return slim.assign_from_checkpoint_fn(
            checkpoint_path,
            variables_to_restore,
            ignore_missing_vars=FLAGS.ignore_missing_vars)

    def _get_variables_to_train():
        """Returns a list of variables to train.

    Returns:
      A list of variables to train by the optimizer.
    """
        if FLAGS.trainable_scopes is None:
            return tf.trainable_variables()
        else:
            scopes = [
                scope.strip() for scope in FLAGS.trainable_scopes.split(',')
            ]

        variables_to_train = []
        for scope in scopes:
            variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope)
            variables_to_train.extend(variables)
        return variables_to_train

    # main
    cluster_spec, server = TFNode.start_cluster_server(ctx, FLAGS.num_gpus)
    if ctx.job_name == 'ps':
        # `ps` jobs wait for incoming connections from the workers.
        server.join()
    else:
        # `worker` jobs will actually do the work.
        if not FLAGS.dataset_dir:
            raise ValueError(
                'You must supply the dataset directory with --dataset_dir')

        tf.logging.set_verbosity(tf.logging.INFO)
        with tf.Graph().as_default():
            ######################
            # Config model_deploy#
            ######################
            deploy_config = model_deploy.DeploymentConfig(
                num_clones=FLAGS.num_clones,
                clone_on_cpu=FLAGS.clone_on_cpu,
                replica_id=FLAGS.task,
                num_replicas=FLAGS.worker_replicas,
                num_ps_tasks=FLAGS.num_ps_tasks)

            # Create global_step
            with tf.device(deploy_config.variables_device()):
                global_step = slim.create_global_step()

            ######################
            # Select the dataset #
            ######################
            dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                  FLAGS.dataset_split_name,
                                                  FLAGS.dataset_dir)

            ######################
            # Select the network #
            ######################
            network_fn = nets_factory.get_network_fn(
                FLAGS.model_name,
                num_classes=(dataset.num_classes - FLAGS.labels_offset),
                weight_decay=FLAGS.weight_decay,
                is_training=True)

            #####################################
            # Select the preprocessing function #
            #####################################
            preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
            image_preprocessing_fn = preprocessing_factory.get_preprocessing(
                preprocessing_name, is_training=True)

            ##############################################################
            # Create a dataset provider that loads data from the dataset #
            ##############################################################
            with tf.device(deploy_config.inputs_device()):
                provider = slim.dataset_data_provider.DatasetDataProvider(
                    dataset,
                    num_readers=FLAGS.num_readers,
                    common_queue_capacity=20 * FLAGS.batch_size,
                    common_queue_min=10 * FLAGS.batch_size)
                [image, label] = provider.get(['image', 'label'])
                label -= FLAGS.labels_offset

                train_image_size = FLAGS.train_image_size or network_fn.default_image_size

                image = image_preprocessing_fn(image, train_image_size,
                                               train_image_size)

                images, labels = tf.train.batch(
                    [image, label],
                    batch_size=FLAGS.batch_size,
                    num_threads=FLAGS.num_preprocessing_threads,
                    capacity=5 * FLAGS.batch_size)
                labels = slim.one_hot_encoding(
                    labels, dataset.num_classes - FLAGS.labels_offset)
                batch_queue = slim.prefetch_queue.prefetch_queue(
                    [images, labels], capacity=2 * deploy_config.num_clones)

            ####################
            # Define the model #
            ####################
            def clone_fn(batch_queue):
                """Allows data parallelism by creating multiple clones of network_fn."""
                images, labels = batch_queue.dequeue()
                logits, end_points = network_fn(images)

                #############################
                # Specify the loss function #
                #############################
                if 'AuxLogits' in end_points:
                    slim.losses.softmax_cross_entropy(
                        end_points['AuxLogits'],
                        labels,
                        label_smoothing=FLAGS.label_smoothing,
                        weights=0.4,
                        scope='aux_loss')
                slim.losses.softmax_cross_entropy(
                    logits,
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=1.0)
                return end_points

            # Gather initial summaries.
            summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

            clones = model_deploy.create_clones(deploy_config, clone_fn,
                                                [batch_queue])
            first_clone_scope = deploy_config.clone_scope(0)
            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by network_fn.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

            # Add summaries for end_points.
            end_points = clones[0].outputs
            for end_point in end_points:
                x = end_points[end_point]
                summaries.add(
                    tf.summary.histogram('activations/' + end_point, x))
                summaries.add(
                    tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

            # Add summaries for losses.
            for loss in tf.get_collection(tf.GraphKeys.LOSSES,
                                          first_clone_scope):
                summaries.add(
                    tf.summary.scalar('losses/%s' % loss.op.name, loss))

            # Add summaries for variables.
            for variable in slim.get_model_variables():
                summaries.add(tf.summary.histogram(variable.op.name, variable))

            #################################
            # Configure the moving averages #
            #################################
            if FLAGS.moving_average_decay:
                moving_average_variables = slim.get_model_variables()
                variable_averages = tf.train.ExponentialMovingAverage(
                    FLAGS.moving_average_decay, global_step)
            else:
                moving_average_variables, variable_averages = None, None

            #########################################
            # Configure the optimization procedure. #
            #########################################
            with tf.device(deploy_config.optimizer_device()):
                learning_rate = _configure_learning_rate(
                    dataset.num_samples, global_step)
                optimizer = _configure_optimizer(learning_rate)
                summaries.add(tf.summary.scalar('learning_rate',
                                                learning_rate))

            if FLAGS.sync_replicas:
                # If sync_replicas is enabled, the averaging will be done in the chief
                # queue runner.
                optimizer = tf.train.SyncReplicasOptimizer(
                    opt=optimizer,
                    replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                    variable_averages=variable_averages,
                    variables_to_average=moving_average_variables,
                    replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                    total_num_replicas=FLAGS.worker_replicas)
            elif FLAGS.moving_average_decay:
                # Update ops executed locally by trainer.
                update_ops.append(
                    variable_averages.apply(moving_average_variables))

            # Variables to train.
            variables_to_train = _get_variables_to_train()

            # Compute the total loss and the gradients across all clones.
            total_loss, clones_gradients = model_deploy.optimize_clones(
                clones, optimizer, var_list=variables_to_train)
            # Add total_loss to summary.
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Create gradient updates.
            grad_updates = optimizer.apply_gradients(clones_gradients,
                                                     global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            train_tensor = control_flow_ops.with_dependencies([update_op],
                                                              total_loss,
                                                              name='train_op')

            # Add the summaries from the first clone. These contain the summaries
            # created by model_fn and either optimize_clones() or _gather_clone_loss().
            summaries |= set(
                tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

            # Merge all summaries together.
            summary_op = tf.summary.merge(list(summaries), name='summary_op')

            ###########################
            # Kicks off the training. #
            ###########################
            summary_writer = tf.summary.FileWriter(
                "tensorboard_%d" % (ctx.worker_num),
                graph=tf.get_default_graph())
            slim.learning.train(
                train_tensor,
                logdir=FLAGS.train_dir,
                master=server.target,
                is_chief=(FLAGS.task == 0),
                init_fn=_get_init_fn(),
                summary_op=summary_op,
                number_of_steps=FLAGS.max_number_of_steps,
                log_every_n_steps=FLAGS.log_every_n_steps,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                summary_writer=summary_writer,
                sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Example 9
def main_fun(argv, ctx):

    import sys
    import math
    import time
    from datetime import datetime
    import numpy as np
    import tensorflow as tf
    import cifar10

    sys.argv = argv
    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_string('eval_dir', '/tmp/cifar10_eval',
                               """Directory where to write event logs.""")
    tf.app.flags.DEFINE_string('eval_data', 'test',
                               """Either 'test' or 'train_eval'.""")
    tf.app.flags.DEFINE_string(
        'checkpoint_dir', '/tmp/cifar10_train',
        """Directory where to read model checkpoints.""")
    tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5,
                                """How often to run the eval.""")
    tf.app.flags.DEFINE_integer('num_examples', 10000,
                                """Number of examples to run.""")
    tf.app.flags.DEFINE_boolean('run_once', False,
                                """Whether to run eval only once.""")
    tf.app.flags.DEFINE_boolean('rdma', False, """Whether to use rdma.""")

    cluster_spec, server = TFNode.start_cluster_server(ctx, 1, FLAGS.rdma)

    def eval_once(saver, summary_writer, top_k_op, summary_op):
        """Run Eval once.

    Args:
      saver: Saver.
      summary_writer: Summary writer.
      top_k_op: Top K op.
      summary_op: Summary op.
    """
        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                # Restores from checkpoint
                saver.restore(sess, ckpt.model_checkpoint_path)
                # Assuming model_checkpoint_path looks something like:
                #   /my-favorite-path/cifar10_train/model.ckpt-0,
                # extract global_step from it.
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                    '-')[-1]
            else:
                print('No checkpoint file found')
                return

            # Start the queue runners.
            coord = tf.train.Coordinator()
            try:
                threads = []
                for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                    threads.extend(
                        qr.create_threads(sess,
                                          coord=coord,
                                          daemon=True,
                                          start=True))

                num_iter = int(math.ceil(FLAGS.num_examples /
                                         FLAGS.batch_size))
                true_count = 0  # Counts the number of correct predictions.
                total_sample_count = num_iter * FLAGS.batch_size
                step = 0
                while step < num_iter and not coord.should_stop():
                    predictions = sess.run([top_k_op])
                    true_count += np.sum(predictions)
                    step += 1

                # Compute precision @ 1.
                precision = true_count / total_sample_count
                print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))

                summary = tf.Summary()
                summary.ParseFromString(sess.run(summary_op))
                summary.value.add(tag='Precision @ 1', simple_value=precision)
                summary_writer.add_summary(summary, global_step)
            except Exception as e:  # pylint: disable=broad-except
                coord.request_stop(e)

            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=10)

    def evaluate():
        """Eval CIFAR-10 for a number of steps."""
        with tf.Graph().as_default() as g:
            # Get images and labels for CIFAR-10.
            eval_data = FLAGS.eval_data == 'test'
            images, labels = cifar10.inputs(eval_data=eval_data)

            # Build a Graph that computes the logits predictions from the
            # inference model.
            logits = cifar10.inference(images)

            # Calculate predictions.
            top_k_op = tf.nn.in_top_k(logits, labels, 1)

            # Restore the moving average version of the learned variables for eval.
            variable_averages = tf.train.ExponentialMovingAverage(
                cifar10.MOVING_AVERAGE_DECAY)
            variables_to_restore = variable_averages.variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)

            # Build the summary operation based on the TF collection of Summaries.
            summary_op = tf.summary.merge_all()

            summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)

            while True:
                eval_once(saver, summary_writer, top_k_op, summary_op)
                if FLAGS.run_once:
                    break
                time.sleep(FLAGS.eval_interval_secs)

    #cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.eval_dir):
        tf.gfile.DeleteRecursively(FLAGS.eval_dir)
    tf.gfile.MakeDirs(FLAGS.eval_dir)
    evaluate()
Example 10
def main_fun(argv, ctx):
    import re
    import sys
    import tensorflow as tf
    import cifar10
    from six.moves import xrange  # pylint: disable=redefined-builtin

    sys.argv = argv
    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_string(
        'train_dir', '/tmp/cifar10_train',
        """Directory where to write event logs """
        """and checkpoint.""")
    tf.app.flags.DEFINE_integer('max_steps', 1000000,
                                """Number of batches to run.""")
    tf.app.flags.DEFINE_integer('num_gpus', 1, """How many GPUs to use.""")
    tf.app.flags.DEFINE_boolean('log_device_placement', False,
                                """Whether to log device placement.""")
    tf.app.flags.DEFINE_boolean('rdma', False, """Whether to use rdma.""")
    cluster_spec, server = TFNode.start_cluster_server(ctx, FLAGS.num_gpus,
                                                       FLAGS.rdma)

    def tower_loss(scope):
        """Calculate the total loss on a single tower running the CIFAR model.

    Args:
      scope: unique prefix string identifying the CIFAR tower, e.g. 'tower_0'

    Returns:
       Tensor of shape [] containing the total loss for a batch of data
    """
        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build inference Graph.
        logits = cifar10.inference(images)

        # Build the portion of the Graph calculating the losses. Note that we will
        # assemble the total_loss using a custom function below.
        _ = cifar10.loss(logits, labels)

        # Assemble all of the losses for the current tower only.
        losses = tf.get_collection('losses', scope)

        # Calculate the total loss for the current tower.
        total_loss = tf.add_n(losses, name='total_loss')

        # Attach a scalar summary to all individual losses and the total loss; do the
        # same for the averaged version of the losses.
        for l in losses + [total_loss]:
            # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
            # session. This helps the clarity of presentation on tensorboard.
            loss_name = re.sub('%s_[0-9]*/' % cifar10.TOWER_NAME, '',
                               l.op.name)
            tf.summary.scalar(loss_name, l)

        return total_loss

    def average_gradients(tower_grads):
        """Calculate the average gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all towers.

    Args:
      tower_grads: List of lists of (gradient, variable) tuples. The outer list
        is over individual gradients. The inner list is over the gradient
        calculation for each tower.
    Returns:
       List of pairs of (gradient, variable) where the gradient has been averaged
       across all towers.
    """
        average_grads = []
        for grad_and_vars in zip(*tower_grads):
            # Note that each grad_and_vars looks like the following:
            #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
            grads = []
            for g, _ in grad_and_vars:
                # Add 0 dimension to the gradients to represent the tower.
                expanded_g = tf.expand_dims(g, 0)

                # Append on a 'tower' dimension which we will average over below.
                grads.append(expanded_g)

            # Average over the 'tower' dimension.
            grad = tf.concat(axis=0, values=grads)
            grad = tf.reduce_mean(grad, 0)

            # Keep in mind that the Variables are redundant because they are shared
            # across towers. So .. we will just return the first tower's pointer to
            # the Variable.
            v = grad_and_vars[0][1]
            grad_and_var = (grad, v)
            average_grads.append(grad_and_var)
        return average_grads

    def train():
        """Train CIFAR-10 for a number of steps."""
        with tf.Graph().as_default(), tf.device('/cpu:0'):
            # Create a variable to count the number of train() calls. This equals the
            # number of batches processed * FLAGS.num_gpus.
            global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0),
                trainable=False)

            # Calculate the learning rate schedule.
            num_batches_per_epoch = (cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
                                     FLAGS.batch_size)
            decay_steps = int(num_batches_per_epoch *
                              cifar10.NUM_EPOCHS_PER_DECAY)

            # Decay the learning rate exponentially based on the number of steps.
            lr = tf.train.exponential_decay(cifar10.INITIAL_LEARNING_RATE,
                                            global_step,
                                            decay_steps,
                                            cifar10.LEARNING_RATE_DECAY_FACTOR,
                                            staircase=True)
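            # With the stock cifar10.py constants (assumed here: 50,000 training
            # examples per epoch, NUM_EPOCHS_PER_DECAY=350, INITIAL_LEARNING_RATE=0.1,
            # LEARNING_RATE_DECAY_FACTOR=0.1) and the default batch_size of 128,
            # decay_steps = int(50000 / 128 * 350) = 136718, so with staircase=True
            # the effective rate is 0.1 * 0.1 ** (global_step // 136718).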

            # Create an optimizer that performs gradient descent.
            opt = tf.train.GradientDescentOptimizer(lr)

            # Calculate the gradients for each model tower.
            tower_grads = []
            with tf.variable_scope(tf.get_variable_scope()):
                for i in xrange(FLAGS.num_gpus):
                    with tf.device('/gpu:%d' % i):
                        with tf.name_scope('%s_%d' %
                                           (cifar10.TOWER_NAME, i)) as scope:
                            # Calculate the loss for one tower of the CIFAR model. This function
                            # constructs the entire CIFAR model but shares the variables across
                            # all towers.
                            loss = tower_loss(scope)

                            # Reuse variables for the next tower.
                            tf.get_variable_scope().reuse_variables()

                            # Retain the summaries from the final tower.
                            summaries = tf.get_collection(
                                tf.GraphKeys.SUMMARIES, scope)

                            # Calculate the gradients for the batch of data on this CIFAR tower.
                            grads = opt.compute_gradients(loss)

                            # Keep track of the gradients across all towers.
                            tower_grads.append(grads)

            # We must calculate the mean of each gradient. Note that this is the
            # synchronization point across all towers.
            grads = average_gradients(tower_grads)

            # Add a summary to track the learning rate.
            summaries.append(tf.summary.scalar('learning_rate', lr))

            # Add histograms for gradients.
            for grad, var in grads:
                if grad is not None:
                    summaries.append(
                        tf.summary.histogram(var.op.name + '/gradients', grad))

            # Apply the gradients to adjust the shared variables.
            apply_gradient_op = opt.apply_gradients(grads,
                                                    global_step=global_step)

            # Add histograms for trainable variables.
            for var in tf.trainable_variables():
                summaries.append(tf.summary.histogram(var.op.name, var))

            # Track the moving averages of all trainable variables.
            variable_averages = tf.train.ExponentialMovingAverage(
                cifar10.MOVING_AVERAGE_DECAY, global_step)
            variables_averages_op = variable_averages.apply(
                tf.trainable_variables())

            # Group all updates into a single train op.
            train_op = tf.group(apply_gradient_op, variables_averages_op)

            # Create a saver.
            saver = tf.train.Saver(tf.global_variables())

            # Build the summary operation from the last tower summaries.
            summary_op = tf.summary.merge(summaries)

            # Build an initialization operation to run below.
            init = tf.global_variables_initializer()

            # Start running operations on the Graph. allow_soft_placement must be set to
            # True to build towers on GPU, as some of the ops do not have GPU
            # implementations.
            sess = tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=FLAGS.log_device_placement))
            sess.run(init)

            # Start the queue runners.
            tf.train.start_queue_runners(sess=sess)

            summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

            for step in xrange(FLAGS.max_steps):
                start_time = time.time()
                _, loss_value = sess.run([train_op, loss])
                duration = time.time() - start_time

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                if step % 10 == 0:
                    num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = duration / FLAGS.num_gpus

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), step, loss_value,
                                        examples_per_sec, sec_per_batch))

                if step % 100 == 0:
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)

                # Save the model checkpoint periodically.
                if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                    checkpoint_path = os.path.join(FLAGS.train_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

    # cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    train()
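The gradient-averaging step above is easiest to see in isolation. Below is a minimal, self-contained sketch (assuming TF 1.x; the constant gradients and variable are made up for illustration) that applies the same zip/stack/reduce_mean pattern to two fake towers sharing one variable.

import tensorflow as tf

v = tf.Variable([0.0, 0.0], name='shared_var')
tower_grads = [
    [(tf.constant([1.0, 2.0]), v)],  # gradients computed on tower 0
    [(tf.constant([3.0, 4.0]), v)],  # gradients computed on tower 1
]

average_grads = []
for grad_and_vars in zip(*tower_grads):
    stacked = tf.stack([g for g, _ in grad_and_vars])  # shape [num_towers, ...]
    average_grads.append((tf.reduce_mean(stacked, 0), grad_and_vars[0][1]))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(average_grads[0][0]))  # -> [2. 3.]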
Example 11
def preprocess_data_and_save(epochs, batch_size, keep_probability, ctx):
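    # Note: job_name, task_index, worker_num, args, helper, and the model helper
    # functions (neural_net_image_input, conv_net, train_neural_network, ...) are
    # not defined in this snippet; they are presumably provided by the enclosing
    # module of the original script.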
    # Get TF cluster and server instances
    cluster, server = TFNode.start_cluster_server(ctx, 1, args.rdma)

    if job_name == "ps":
        server.join()
    elif job_name == "worker":
        # Assigns ops to the local worker by default.
        with tf.device(
                tf.train.replica_device_setter(
                    worker_device="/job:worker/task:%d" % task_index,
                    cluster=cluster)):

            # Preprocess Training, Validation, and Testing Data
            helper.preprocess_and_save_data(cifar10_dataset_folder_path,
                                            normalize, one_hot_encode)

            # Load the Preprocessed Validation data
            valid_features, valid_labels = pickle.load(
                open('preprocess_validation.p', mode='rb'))

            # Remove previous weights, bias, inputs, etc..
            tf.reset_default_graph()

            # Inputs
            x = neural_net_image_input((32, 32, 3))
            y = neural_net_label_input(10)
            keep_prob = neural_net_keep_prob_input()

            # Model
            logits = conv_net(x, keep_prob)

            # Name the logits Tensor, so that it can be loaded from disk after training
            logits = tf.identity(logits, name='logits')

            # Loss and Optimizer
            global_step = tf.Variable(0)
            cost = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                        labels=y))
            optimizer = tf.train.AdamOptimizer(1e-4).minimize(
                cost, global_step=global_step)

            # Accuracy
            correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32),
                                      name='accuracy')

            save_model_path = './image_classification'

            # Save Model
            saver = tf.train.Saver()
            summary_op = tf.summary.merge_all()
            init_op = tf.global_variables_initializer()

        # Create a "supervisor", which oversees the training process and stores model state into HDFS
        logdir = TFNode.hdfs_path(ctx, args.model)
        print("tensorflow model path: {0}".format(logdir))
        summary_writer = tf.summary.FileWriter("tensorboard_%d" % (worker_num),
                                               graph=tf.get_default_graph())
        sv = tf.train.Supervisor(is_chief=(task_index == 0),
                                 logdir=logdir,
                                 init_op=init_op,
                                 summary_op=summary_op,
                                 global_step=global_step,
                                 summary_writer=summary_writer,
                                 saver=saver,
                                 save_model_secs=10)

        print('Training...')
        with sv.managed_session(server.target) as sess:
            # Initializing the variables
            # init_op = tf.global_variables_initializer()
            # sess.run(tf.global_variables_initializer())

            # Training cycle
            for epoch in range(epochs):
                # Loop over all batches
                n_batches = 5
                for batch_i in range(1, n_batches + 1):
                    for batch_features, batch_labels in helper.load_preprocess_training_batch(
                            batch_i, batch_size):
                        train_neural_network(sess, optimizer, keep_probability,
                                             batch_features, batch_labels, x,
                                             y, keep_prob)


                    # print('Epoch {:>2}, CIFAR-10 Batch {}:  '.format(epoch + 1, batch_i), end='')
                    print_stats(sess, batch_features, batch_labels, cost,
                                accuracy, x, y, keep_prob, valid_features,
                                valid_labels)

            # Save Model
            save_path = saver.save(sess, save_model_path)

        # Ask for all the services to stop.
        print("{0} stopping supervisor".format(datetime.now().isoformat()))
        sv.stop()
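A minimal sketch (assuming TF 1.x; the host:port strings are placeholders) of the device-placement pattern used above: tf.train.replica_device_setter pins variables to the ps job while ordinary ops stay on the local worker.

import tensorflow as tf

cluster = tf.train.ClusterSpec({
    'ps': ['ps0.example.com:2222'],
    'worker': ['worker0.example.com:2222', 'worker1.example.com:2222'],
})

with tf.device(tf.train.replica_device_setter(
        worker_device='/job:worker/task:0', cluster=cluster)):
    w = tf.Variable(tf.zeros([10]), name='w')  # assigned to /job:ps/task:0
    y = tf.reduce_sum(w)                       # assigned to /job:worker/task:0

print(w.device, y.device)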
Example 12
def main_fun(argv, ctx):
    # Presumably module-level imports in the original script; repeated here so the
    # map function is self-contained when shipped to executors.
    import sys
    import time
    from datetime import datetime

    import tensorflow as tf
    from tensorflowonspark import TFNode

    import cifar10

    sys.argv = argv
    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_string(
        'train_dir', '/tmp/cifar10_train',
        """Directory where to write event logs """
        """and checkpoint.""")
    tf.app.flags.DEFINE_integer('max_steps', 1000000,
                                """Number of batches to run.""")
    tf.app.flags.DEFINE_boolean('log_device_placement', False,
                                """Whether to log device placement.""")
    tf.app.flags.DEFINE_boolean('rdma', False, """Whether to use rdma.""")

    # cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)

    cluster_spec, server = TFNode.start_cluster_server(ctx, 1, FLAGS.rdma)

    # Train CIFAR-10 for a number of steps.
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1

            def before_run(self, run_context):
                self._step += 1
                self._start_time = time.time()
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                duration = time.time() - self._start_time
                loss_value = run_values.results
                if self._step % 10 == 0:
                    num_examples_per_step = FLAGS.batch_size
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
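For context: a main_fun(argv, ctx) like the ones above is normally launched from the Spark driver with TensorFlowOnSpark's TFCluster. The sketch below is an assumption about that launch code, not part of this example; the exact TFCluster.run argument list varies between TensorFlowOnSpark releases, and the executor/ps counts are placeholders.

import sys
from pyspark import SparkConf, SparkContext
from tensorflowonspark import TFCluster

sc = SparkContext(conf=SparkConf().setAppName('cifar10_train'))
num_executors = 3  # placeholder: total Spark executors reserved for TF nodes
num_ps = 1         # placeholder: how many of them run as parameter servers

cluster = TFCluster.run(sc, main_fun, sys.argv, num_executors, num_ps,
                        tensorboard=False, input_mode=TFCluster.InputMode.TENSORFLOW)
cluster.shutdown()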
def train(target, dataset, cluster_spec, ctx):
    """Train Inception on a dataset for a number of steps."""
    # The numbers of workers and parameter servers are inferred from the worker
    # and ps host strings.
    num_workers = len(cluster_spec.as_dict()['worker'])
    num_parameter_servers = len(cluster_spec.as_dict()['ps'])
    # If no value is given, num_replicas_to_aggregate defaults to be the number of
    # workers.
    if FLAGS.num_replicas_to_aggregate == -1:
        num_replicas_to_aggregate = num_workers
    else:
        num_replicas_to_aggregate = FLAGS.num_replicas_to_aggregate

    # Both should be greater than 0 in a distributed training.
    assert num_workers > 0 and num_parameter_servers > 0, (
        ' num_workers and '
        'num_parameter_servers'
        ' must be > 0.')

    # Choose worker 0 as the chief. Note that any worker could be the chief
    # but there should be only one chief.
    is_chief = (FLAGS.task_id == 0)

    # Ops are assigned to worker by default.
    with tf.device('/job:worker/task:%d' % FLAGS.task_id):
        # Variables and their related init/assign ops are assigned to ps.
        with slim.scopes.arg_scope(
            [slim.variables.variable, slim.variables.global_step],
                device=slim.variables.VariableDeviceChooser(
                    num_parameter_servers)):
            # Create a variable to count the number of train() calls. This equals the
            # number of updates applied to the variables.
            global_step = slim.variables.global_step()

            # Calculate the learning rate schedule.
            num_batches_per_epoch = (dataset.num_examples_per_epoch() /
                                     FLAGS.batch_size)
            # Decay steps need to be divided by the number of replicas to aggregate.
            decay_steps = int(num_batches_per_epoch *
                              FLAGS.num_epochs_per_decay /
                              num_replicas_to_aggregate)

            # Decay the learning rate exponentially based on the number of steps.
            lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                            global_step,
                                            decay_steps,
                                            FLAGS.learning_rate_decay_factor,
                                            staircase=True)
            # Add a summary to track the learning rate.
            tf.summary.scalar('learning_rate', lr)

            # Create an optimizer that performs gradient descent.
            opt = tf.train.RMSPropOptimizer(lr,
                                            RMSPROP_DECAY,
                                            momentum=RMSPROP_MOMENTUM,
                                            epsilon=RMSPROP_EPSILON)

            if FLAGS.input_mode == 'spark':

                def feed_dict(mgr, batch_size):
                    tmp = TFNode.next_batch(mgr, batch_size)
                    # extract TFRecords, since tmp array is [(TFRecord, None)]
                    tfrecords = []
                    for elem in tmp:
                        tfrecords.append(str(elem[0]))
                    return tfrecords

                batch = tf.placeholder(
                    tf.string,
                    [FLAGS.batch_size / FLAGS.num_preprocess_threads])

                # The following is adapted from image_processing.py to remove Readers/QueueRunners.
                # Note: this removes the RandomShuffledQueue, so the incoming data is not shuffled.
                # Presumably, this could be done on the Spark side or done in additional TF code.
                examples = tf.unpack(batch)
                images, labels = [], []
                for example_serialized in examples:
                    for thread_id in range(FLAGS.num_preprocess_threads):
                        # Parse a serialized Example proto to extract the image and metadata.
                        image_buffer, label_index, bbox, _ = image_processing.parse_example_proto(
                            example_serialized)
                        image = image_processing.image_preprocessing(
                            image_buffer, bbox, train, thread_id)
                        images.append(image)
                        labels.append(label_index)
                height = FLAGS.image_size
                width = FLAGS.image_size
                depth = 3
                images = tf.cast(images, tf.float32)
                images = tf.reshape(
                    images, shape=[FLAGS.batch_size, height, width, depth])
                tf.summary.image('images', images)
                labels = tf.reshape(labels, [FLAGS.batch_size])
            else:
                images, labels = image_processing.distorted_inputs(
                    dataset,
                    batch_size=FLAGS.batch_size,
                    num_preprocess_threads=FLAGS.num_preprocess_threads)

            # Number of classes in the Dataset label set plus 1.
            # Label 0 is reserved for an (unused) background class.
            num_classes = dataset.num_classes() + 1
            logits = inception.inference(images,
                                         num_classes,
                                         for_training=True)
            # Add classification loss.
            inception.loss(logits, labels)

            # Gather all of the losses including regularization losses.
            losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
            losses += tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

            total_loss = tf.add_n(losses, name='total_loss')

            if is_chief:
                # Compute the moving average of all individual losses and the
                # total loss.
                loss_averages = tf.train.ExponentialMovingAverage(0.9,
                                                                  name='avg')
                loss_averages_op = loss_averages.apply(losses + [total_loss])

                # Attach a scalar summary to all individual losses and the total loss;
                # do the same for the averaged version of the losses.
                for l in losses + [total_loss]:
                    loss_name = l.op.name
                    # Name each loss as '(raw)' and name the moving average version of the
                    # loss as the original loss name.
                    tf.summary.scalar(loss_name + ' (raw)', l)
                    tf.summary.scalar(loss_name, loss_averages.average(l))

                # Add dependency to compute loss_averages.
                with tf.control_dependencies([loss_averages_op]):
                    total_loss = tf.identity(total_loss)

            # Track the moving averages of all trainable variables.
            # Note that we maintain a 'double-average' of the BatchNormalization
            # global statistics.
            # This is not needed when the number of replicas is small, but it is
            # important for synchronous distributed training with tens of workers/replicas.
            exp_moving_averager = tf.train.ExponentialMovingAverage(
                inception.MOVING_AVERAGE_DECAY, global_step)

            variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())

            # Add histograms for model variables.
            for var in variables_to_average:
                tf.summary.histogram(var.op.name, var)

            # Create synchronous replica optimizer.
            opt = tf.train.SyncReplicasOptimizer(
                opt,
                replicas_to_aggregate=num_replicas_to_aggregate,
                replica_id=FLAGS.task_id,
                total_num_replicas=num_workers,
                variable_averages=exp_moving_averager,
                variables_to_average=variables_to_average)

            batchnorm_updates = tf.get_collection(
                slim.ops.UPDATE_OPS_COLLECTION)
            assert batchnorm_updates, 'Batchnorm updates are missing'
            batchnorm_updates_op = tf.group(*batchnorm_updates)
            # Add dependency to compute batchnorm_updates.
            with tf.control_dependencies([batchnorm_updates_op]):
                total_loss = tf.identity(total_loss)

            # Compute gradients with respect to the loss.
            grads = opt.compute_gradients(total_loss)

            # Add histograms for gradients.
            for grad, var in grads:
                if grad is not None:
                    tf.summary.histogram(var.op.name + '/gradients', grad)

            apply_gradients_op = opt.apply_gradients(grads,
                                                     global_step=global_step)

            with tf.control_dependencies([apply_gradients_op]):
                train_op = tf.identity(total_loss, name='train_op')

            # Get the chief queue_runners, init_tokens_op and clean_up_op, which are
            # used to synchronize replicas. More details can be found in
            # sync_replicas_optimizer.
            chief_queue_runners = [opt.get_chief_queue_runner()]
            init_tokens_op = opt.get_init_tokens_op()
            clean_up_op = opt.get_clean_up_op()

            # Create a saver.
            saver = tf.train.Saver()

            # Build the summary operation based on the TF collection of Summaries.
            summary_op = tf.summary.merge_all()

            # Build an initialization operation to run below.
            init_op = tf.global_variables_initializer()

            # We run the summaries in the same thread as the training operations by
            # passing in None for summary_op to avoid a summary_thread being started.
            # Running summaries and training operations in parallel could run out of
            # GPU memory.
            summary_writer = tf.summary.FileWriter(
                "tensorboard_%d" % (ctx.worker_num),
                graph=tf.get_default_graph())
            sv = tf.train.Supervisor(is_chief=is_chief,
                                     logdir=FLAGS.train_dir,
                                     init_op=init_op,
                                     summary_op=None,
                                     global_step=global_step,
                                     summary_writer=summary_writer,
                                     saver=saver,
                                     save_model_secs=FLAGS.save_interval_secs)

            tf.logging.info('%s Supervisor' % datetime.now())

            sess_config = tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=FLAGS.log_device_placement)

            # Get a session.
            sess = sv.prepare_or_wait_for_session(target, config=sess_config)

            # Start the queue runners.
            queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
            sv.start_queue_runners(sess, queue_runners)
            tf.logging.info('Started %d queues for processing input data.',
                            len(queue_runners))

            if is_chief:
                sv.start_queue_runners(sess, chief_queue_runners)
                sess.run(init_tokens_op)

            # Train, checking for Nans. Concurrently run the summary operation at a
            # specified interval. Note that the summary_op and train_op never run
            # simultaneously in order to prevent running out of GPU memory.
            next_summary_time = time.time() + FLAGS.save_summaries_secs
            while not sv.should_stop():
                try:
                    start_time = time.time()
                    if FLAGS.input_mode == 'spark':
                        tmp = feed_dict(
                            ctx.mgr,
                            FLAGS.batch_size / FLAGS.num_preprocess_threads)
                        feed = {batch: tmp}
                        loss_value, step = sess.run([train_op, global_step],
                                                    feed_dict=feed)
                    else:
                        loss_value, step = sess.run([train_op, global_step])
                    assert not np.isnan(
                        loss_value), 'Model diverged with loss = NaN'
                    if step > FLAGS.max_steps:
                        break
                    duration = time.time() - start_time

                    if step % 30 == 0:
                        examples_per_sec = FLAGS.batch_size / float(duration)
                        format_str = ('Worker %d: %s: step %d, loss = %.2f '
                                      '(%.1f examples/sec; %.3f sec/batch)')
                        tf.logging.info(
                            format_str %
                            (FLAGS.task_id, datetime.now(), step, loss_value,
                             examples_per_sec, duration))

                    # Determine if the summary_op should be run on the chief worker.
                    if (FLAGS.input_mode == 'tf' and is_chief and
                            next_summary_time < time.time()):
                        tf.logging.info(
                            'Running Summary operation on the chief.')
                        summary_str = sess.run(summary_op)
                        sv.summary_computed(sess, summary_str)
                        tf.logging.info('Finished running Summary operation.')

                        # Determine the next time for running the summary.
                        next_summary_time += FLAGS.save_summaries_secs
                except:
                    if is_chief:
                        tf.logging.info('About to execute sync_clean_up_op!')
                        sess.run(clean_up_op)
                    raise

            # Stop the TFNode data feed
            if FLAGS.input_mode == 'spark':
                TFNode.terminate(ctx.mgr)

            # Stop the supervisor.  This also waits for service threads to finish.
            sv.stop()

            # Save after the training ends.
            if is_chief:
                saver.save(sess,
                           os.path.join(FLAGS.train_dir, 'model.ckpt'),
                           global_step=global_step)
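A minimal, self-contained sketch (assuming TF 1.x; no Spark involved, and the toy 'label' feature is made up) of the 'spark' input mode above: serialized Example protos arrive through a string placeholder and are parsed in-graph, rather than being read by file readers and queue runners.

import tensorflow as tf

def make_example(label):
    return tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
    })).SerializeToString()

batch = tf.placeholder(tf.string, [None])  # one serialized Example proto per element
parsed = tf.parse_example(batch, {'label': tf.FixedLenFeature([], tf.int64)})

with tf.Session() as sess:
    records = [make_example(7), make_example(9)]  # stand-ins for TFNode.next_batch() output
    print(sess.run(parsed['label'], feed_dict={batch: records}))  # -> [7 9]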