Example #1
    def _tf_export(args):
      """Creates an inference graph w/ placeholder and loads weights from checkpoint"""
      import tensorflow as tf
      from tensorflowonspark import TFNode

      tf.reset_default_graph()                          # reset graph in case we're re-using a Spark python worker
      x = tf.placeholder(tf.float32, [None, 2], name='x')
      w = tf.Variable(tf.truncated_normal([2,1]), name='w')
      y = tf.matmul(x, w, name='y')
      y2 = tf.square(y, name="y2")                      # extra/optional output for testing multiple output tensors
      saver = tf.train.Saver()

      with tf.Session() as sess:
        # load graph from a checkpoint
        ckpt = tf.train.get_checkpoint_state(args.model_dir)
        assert ckpt and ckpt.model_checkpoint_path, "Invalid model checkpoint path: {}".format(args.model_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)

        # exported signatures defined in code
        signatures = {
          'test_key': {
            'inputs': { 'features': x },
            'outputs': { 'prediction': y, 'pred2': y2 },
            'method_name': 'test'
          }
        }
        TFNode.export_saved_model(sess, export_dir=args.export_dir, tag_set='test_tag', signatures=signatures)
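A model exported this way can be loaded back with the stock TF 1.x loader. A minimal sketch (the tag and tensor names match the export above; the export path is assumed to be the same args.export_dir):

    import tensorflow as tf

    with tf.Session(graph=tf.Graph()) as sess:
      # load the SavedModel exported above under its tag set
      tf.saved_model.loader.load(sess, ['test_tag'], args.export_dir)
      # fetch the tensors named in the 'test_key' signature
      pred, pred2 = sess.run(['y:0', 'y2:0'], feed_dict={'x:0': [[1.0, 2.0]]})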
Example #2
    def end(self, session):
        if self.mode != 'train':
            return

        print("{} ======= Exporting to: {}".format(datetime.now().isoformat(),
                                                   self.export_dir))
        signatures = {
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            {
                'inputs': {
                    'image': self.input_tensor
                },
                'outputs': {
                    'prediction': self.output_tensor
                },
                'method_name':
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME
            }
        }

        # save and export the model
        TFNode.export_saved_model(session, self.export_dir,
                                  tf.saved_model.tag_constants.SERVING,
                                  signatures)
        print("{} ======= Done exporting".format(datetime.now().isoformat()))
Example #3
    def _spark_train(args, ctx):
      """Basic linear regression in a distributed TF cluster using InputMode.SPARK"""
      import tensorflow as tf
      from tensorflowonspark import TFNode

      tf.reset_default_graph()                          # reset graph in case we're re-using a Spark python worker

      cluster, server = TFNode.start_cluster_server(ctx)
      if ctx.job_name == "ps":
        server.join()
      elif ctx.job_name == "worker":
        with tf.device(tf.train.replica_device_setter(
          worker_device="/job:worker/task:%d" % ctx.task_index,
          cluster=cluster)):
          x = tf.placeholder(tf.float32, [None, 2], name='x')
          y_ = tf.placeholder(tf.float32, [None, 1], name='y_')
          w = tf.Variable(tf.truncated_normal([2,1]), name='w')
          y = tf.matmul(x, w, name='y')
          y2 = tf.square(y, name="y2")                      # extra/optional output for testing multiple output tensors
          cost = tf.reduce_mean(tf.square(y_ - y), name='cost')
          optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cost)
          init_op = tf.global_variables_initializer()
          saver = tf.train.Saver()

        sv = tf.train.Supervisor(is_chief=(ctx.task_index == 0),
                                 init_op=init_op)
        with sv.managed_session(server.target) as sess:
          tf_feed = TFNode.DataFeed(ctx.mgr, input_mapping=args.input_mapping)
          while not sv.should_stop() and not tf_feed.should_stop():
            batch = tf_feed.next_batch(10)
            # skip empty batches (e.g. at end of feed) to avoid running with a stale/undefined feed dict
            if args.input_mapping and len(batch['x']) > 0:
              feed = {x: batch['x'], y_: batch['y_']}
              sess.run(optimizer, feed_dict=feed)

          if sv.is_chief:
            if args.model_dir:
              # manually save checkpoint
              ckpt_name = args.model_dir + "/model.ckpt"
              print("Saving checkpoint to: {}".format(ckpt_name))
              saver.save(sess, ckpt_name)
            elif args.export_dir:
              # export a saved_model
              signatures = {
                'test_key': {
                  'inputs': { 'features': x },
                  'outputs': { 'prediction': y },
                  'method_name': 'test'
                }
              }
              TFNode.export_saved_model(sess, export_dir=args.export_dir, tag_set='test_tag', signatures=signatures)
            else:
              print("WARNING: model state not saved.")

        sv.stop()
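The input_mapping consumed by TFNode.DataFeed above is a plain dict mapping Spark DataFrame column names to input tensor names, and next_batch returns a dict keyed by tensor name (hence batch['x'] and batch['y_']). For this graph it would look something like the following; the column names are assumptions that depend on the caller's DataFrame:

    # hypothetical Spark-side mapping: columns 'features'/'labels' feed placeholders 'x'/'y_'
    input_mapping = {'features': 'x', 'labels': 'y_'}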
Example #4
def export_fun(args):
  """Define/export a single-node TF graph for inferencing"""
  # Input placeholder for inferencing
  x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")

  # Variables of the hidden layer
  hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
                      stddev=1.0 / IMAGE_PIXELS), name="hid_w")
  hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
  tf.summary.histogram("hidden_weights", hid_w)

  # Variables of the softmax layer
  sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
                     stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
  sm_b = tf.Variable(tf.zeros([10]), name="sm_b")

  hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
  hid = tf.nn.relu(hid_lin)
  y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
  prediction = tf.argmax(y, 1, name="prediction")

  saver = tf.train.Saver()

  with tf.Session() as sess:
    # load graph from a checkpoint
    logging.info("model path: {}".format(args.model_dir))
    ckpt = tf.train.get_checkpoint_state(args.model_dir)
    logging.info("ckpt: {}".format(ckpt))
    assert ckpt and ckpt.model_checkpoint_path, "Invalid model checkpoint path: {}".format(args.model_dir)
    saver.restore(sess, ckpt.model_checkpoint_path)

    logging.info("Exporting saved_model to: {}".format(args.export_dir))
    # exported signatures defined in code
    signatures = {
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: {
        'inputs': {'image': x},
        'outputs': {'prediction': prediction},
        'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
      },
      'featurize': {
        'inputs': {'image': x},
        'outputs': {'features': hid},
        'method_name': 'featurize'
      }
    }
    TFNode.export_saved_model(sess,
                              args.export_dir,
                              tf.saved_model.tag_constants.SERVING,
                              signatures)
    logging.info("Exported saved_model")
Example #5
 def end(self, session):
     print("{} ======= Exporting to: {}".format(
         datetime.now().isoformat(), self.export_dir))
     signatures = {
         "test_key": {
             'inputs': {
                 'features': self.input_tensor
             },
             'outputs': {
                 'prediction': self.output_tensor
             },
             'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
         }
     }
     TFNode.export_saved_model(session, self.export_dir,
                               "test_tag", signatures)
     print("{} ======= Done exporting".format(
         datetime.now().isoformat()))
Example #6
    def end(self, session):
        logging.info("{} ======= Exporting to: {}".format(
            datetime.now().isoformat(), self.export_dir))
        signatures = {
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            {
                'inputs': {
                    'image': self.input_tensor
                },
                'outputs': {
                    'prediction': self.output_tensor
                },
                'method_name':
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME
            }
        }

        TFNode.export_saved_model(session,
                                  self.export_dir + '_' + str(random.random()),
                                  tf.saved_model.tag_constants.SERVING,
                                  signatures)
        logging.info("{} ====== Done exporting".format(
            datetime.now().isoformat()))
Example #7
def main(_):
  # restore graph/session from checkpoint
  sess = tf.Session(graph=tf.get_default_graph())
  ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
  saver = tf.train.import_meta_graph(ckpt + '.meta', clear_devices=True)
  saver.restore(sess, ckpt)
  g = sess.graph

  # if --show, dump out all operations in this graph
  if FLAGS.show:
    for o in g.get_operations():
      print("{:>64}\t{}".format(o.name, o.type))

  if FLAGS.export_dir and FLAGS.signatures:
    # load/parse JSON signatures
    if ':' in FLAGS.signatures:
      # assume JSON string, since unix filenames shouldn't contain colons
      signatures = json.loads(FLAGS.signatures)
    else:
      # assume JSON file
      with open(FLAGS.signatures) as f:
        signatures = json.load(f)

    # convert string input/output values with actual tensors from graph
    for name, sig in signatures.items():
      for k, v in sig['inputs'].items():
        tensor_name = v if v.endswith(':0') else v + ':0'
        sig['inputs'][k] = g.get_tensor_by_name(tensor_name)
      for k, v in sig['outputs'].items():
        tensor_name = v if v.endswith(':0') else v + ':0'
        sig['outputs'][k] = g.get_tensor_by_name(tensor_name)

    # export a saved model
    TFNode.export_saved_model(sess,
                              FLAGS.export_dir,
                              tf.saved_model.tag_constants.SERVING,
                              signatures)
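The conversion loop above expects the JSON values to be tensor names (the ':0' suffix is appended when missing). For the MNIST graphs in these examples, the parsed --signatures content would look roughly like this (names assumed to match the restored graph):

    # what json.load(f) should produce:
    signatures = {
        "serving_default": {                          # DEFAULT_SERVING_SIGNATURE_DEF_KEY
            "inputs": {"image": "x"},                 # tensor names, not tensors
            "outputs": {"prediction": "prediction"},
            "method_name": "tensorflow/serving/predict"   # PREDICT_METHOD_NAME
        }
    }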
Example #8
def save_model(sess, args, x, prediction):
    """ 保存模型 """

    pb_folder_dir = args.export_dir + constants.PATH_SEP + constants.PB_FOLDER_NAME
    # exported signatures defined in code
    signatures = {
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: {
            "inputs": {constants.SIG_INPUT: x},
            "outputs": {constants.SIG_OUTPUT: prediction},
            "method_name": tf.saved_model.signature_constants.PREDICT_METHOD_NAME
        }
    }
    TFNode.export_saved_model(sess,
                              pb_folder_dir,
                              tf.saved_model.tag_constants.SERVING,
                              signatures)

    # convert the saved_model to a single .pb file
    t = Thread(target=tensorflow_utils.convert_as_single_pb,
               args=[pb_folder_dir,
                     constants.PREDICT_NODE_NAME,
                     args.export_dir + constants.PATH_SEP + constants.PB_NAME])
    t.start()
    t.join()
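tensorflow_utils.convert_as_single_pb is project-specific code, but "convert to a single .pb" conventionally means freezing the exported variables into graph constants. A sketch of what such a helper might do, under that assumption (not the project's actual implementation):

    import tensorflow as tf

    def convert_as_single_pb(saved_model_dir, output_node_name, pb_path):
      """Freeze a SavedModel's variables into a single GraphDef .pb file."""
      with tf.Session(graph=tf.Graph()) as sess:
        tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], saved_model_dir)
        frozen = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), [output_node_name])
      with tf.gfile.GFile(pb_path, 'wb') as f:
        f.write(frozen.SerializeToString())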
Example #9
def map_fun(args, ctx):
  from tensorflowonspark import TFNode
  from datetime import datetime
  import math
  import numpy
  import tensorflow as tf
  import time

  worker_num = ctx.worker_num
  job_name = ctx.job_name
  task_index = ctx.task_index

  IMAGE_PIXELS = 28

  # Delay PS nodes a bit, since workers seem to reserve GPUs more quickly/reliably (w/o conflict)
  if job_name == "ps":
    time.sleep((worker_num + 1) * 5)

  # Parameters
  hidden_units = 128
  batch_size = args.batch_size

  # Get TF cluster and server instances
  cluster, server = TFNode.start_cluster_server(ctx, 1, args.protocol == 'rdma')

  def feed_dict(batch):
    # Convert from dict of named arrays to two numpy arrays of the proper type
    images = batch['image']
    labels = batch['label']
    xs = numpy.array(images)
    xs = xs.astype(numpy.float32)
    xs = xs / 255.0
    ys = numpy.array(labels)
    ys = ys.astype(numpy.uint8)
    return (xs, ys)

  if job_name == "ps":
    server.join()
  elif job_name == "worker":

    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
      worker_device="/job:worker/task:%d" % task_index,
      cluster=cluster)):

      # Variables of the hidden layer
      hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
                          stddev=1.0 / IMAGE_PIXELS), name="hid_w")
      hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
      tf.summary.histogram("hidden_weights", hid_w)

      # Variables of the softmax layer
      sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
                         stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
      sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
      tf.summary.histogram("softmax_weights", sm_w)

      # Placeholders or QueueRunner/Readers for input data
      x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
      y_ = tf.placeholder(tf.float32, [None, 10], name="y_")

      x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
      tf.summary.image("x_img", x_img)

      hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
      hid = tf.nn.relu(hid_lin)

      y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))

      global_step = tf.Variable(0)

      loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
      tf.summary.scalar("loss", loss)

      train_op = tf.train.AdagradOptimizer(0.01).minimize(
          loss, global_step=global_step)

      # Test trained model
      label = tf.argmax(y_, 1, name="label")
      prediction = tf.argmax(y, 1, name="prediction")
      correct_prediction = tf.equal(prediction, label)

      accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
      tf.summary.scalar("acc", accuracy)

      saver = tf.train.Saver()
      summary_op = tf.summary.merge_all()
      init_op = tf.global_variables_initializer()

    # Create a "supervisor", which oversees the training process and stores model state into HDFS
    logdir = TFNode.hdfs_path(ctx, args.model_dir)
    print("tensorflow model path: {0}".format(logdir))
    summary_writer = tf.summary.FileWriter("tensorboard_%d" % (worker_num), graph=tf.get_default_graph())

    sv = tf.train.Supervisor(is_chief=(task_index == 0),
                             logdir=logdir,
                             init_op=init_op,
                             summary_op=None,
                             saver=saver,
                             global_step=global_step,
                             stop_grace_secs=300,
                             save_model_secs=10)

    # The supervisor takes care of session initialization, restoring from
    # a checkpoint, and closing when done or an error occurs.
    with sv.managed_session(server.target) as sess:
      print("{0} session ready".format(datetime.now().isoformat()))

      # Loop until the supervisor shuts down or 1000000 steps have completed.
      step = 0
      tf_feed = TFNode.DataFeed(ctx.mgr, input_mapping=args.input_mapping)
      while not sv.should_stop() and not tf_feed.should_stop() and step < args.steps:
        # Run a training step asynchronously.
        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
        # perform *synchronous* training.

        # using feed_dict
        batch_xs, batch_ys = feed_dict(tf_feed.next_batch(batch_size))
        feed = {x: batch_xs, y_: batch_ys}

        if len(batch_xs) > 0:
          _, summary, step = sess.run([train_op, summary_op, global_step], feed_dict=feed)
          # print accuracy and save model checkpoint to HDFS every 100 steps
          if (step % 100 == 0):
            print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy, {x: batch_xs, y_: batch_ys})))

          if sv.is_chief:
            summary_writer.add_summary(summary, step)

      if sv.should_stop() or step >= args.steps:
        tf_feed.terminate()

      if sv.is_chief and args.export_dir:
        print("{0} exporting saved_model to: {1}".format(datetime.now().isoformat(), args.export_dir))
        # exported signatures defined in code
        signatures = {
          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: {
            'inputs': {'image': x},
            'outputs': {'prediction': prediction},
            'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
          },
          'featurize': {
            'inputs': {'image': x},
            'outputs': {'features': hid},
            'method_name': 'featurize'
          }
        }
        TFNode.export_saved_model(sess,
                                  args.export_dir,
                                  tf.saved_model.tag_constants.SERVING,
                                  signatures)
      else:
        # non-chief workers should wait for chief
        while not sv.should_stop():
          print("Waiting for chief")
          time.sleep(5)

    # Ask for all the services to stop.
    print("{0} stopping supervisor".format(datetime.now().isoformat()))
    sv.stop()
Example #10
def export(args):
    """Restore an Inception checkpoint and export it as a saved_model."""
    FLAGS = tf.app.flags.FLAGS
    tf.reset_default_graph()

    def preprocess_image(image_buffer):
        """Preprocess JPEG encoded bytes to 3D float Tensor."""

        # Decode the string as an RGB JPEG.
        # Note that the resulting image contains an unknown height and width
        # that is set dynamically by decode_jpeg. In other words, the height
        # and width of image is unknown at compile-time.
        image = tf.image.decode_jpeg(image_buffer, channels=3)
        # After this point, all image pixels reside in [0,1)
        # until the very end, when they're rescaled to (-1, 1).  The various
        # adjust_* ops all require this range for dtype float.
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        # Crop the central region of the image with an area containing 87.5% of
        # the original image.
        image = tf.image.central_crop(image, central_fraction=0.875)
        # Resize the image to the original height and width.
        image = tf.expand_dims(image, 0)
        image = tf.image.resize_bilinear(image,
                                         [FLAGS.image_size, FLAGS.image_size],
                                         align_corners=False)
        image = tf.squeeze(image, [0])
        # Finally, rescale to [-1,1] instead of [0, 1)
        image = tf.subtract(image, 0.5)
        image = tf.multiply(image, 2.0)
        return image

    # Get images and labels from the dataset.
    jpegs = tf.placeholder(tf.string, [None], name='jpegs')
    images = tf.map_fn(preprocess_image, jpegs, dtype=tf.float32)
    labels = tf.placeholder(tf.int32, [None], name='labels')

    # Number of classes in the Dataset label set plus 1.
    # Label 0 is reserved for an (unused) background class.
    dataset = ImagenetData(subset=FLAGS.subset)

    num_classes = dataset.num_classes() + 1

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits, _ = inception.inference(images, num_classes)

    # Calculate predictions.
    top_1_op = tf.nn.in_top_k(logits, labels, 1)
    top_5_op = tf.nn.in_top_k(logits, labels, 5)

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(
        inception.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if not ckpt or not ckpt.model_checkpoint_path:
            raise Exception("No checkpoint file found at: {}".format(
                FLAGS.train_dir))
        print("ckpt.model_checkpoint_path: {0}".format(
            ckpt.model_checkpoint_path))

        saver.restore(sess, ckpt.model_checkpoint_path)

        # Assuming model_checkpoint_path looks something like:
        #   /my-favorite-path/imagenet_train/model.ckpt-0,
        # extract global_step from it.
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        print('Successfully loaded model from %s at step=%s.' %
              (ckpt.model_checkpoint_path, global_step))

        print("Exporting saved_model to: {}".format(args.export_dir))
        # exported signatures defined in code
        signatures = {
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            {
                'inputs': {
                    'jpegs': jpegs
                },
                'outputs': {
                    'logits': logits
                },
                'method_name':
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME
            }
        }
        TFNode.export_saved_model(sess, args.export_dir,
                                  tf.saved_model.tag_constants.SERVING,
                                  signatures)
        print("Exported saved_model")
Example #11
def map_fun(args, ctx):
  from datetime import datetime
  from tensorflowonspark import TFNode
  import math
  import os
  import tensorflow as tf
  import time

  num_workers = len(ctx.cluster_spec['worker'])
  worker_num = ctx.worker_num
  job_name = ctx.job_name
  task_index = ctx.task_index

  # Parameters
  IMAGE_PIXELS = 28
  hidden_units = 128

  # Get TF cluster and server instances
  cluster, server = ctx.start_cluster_server(1, args.rdma)

  def _parse_csv(ln):
    splits = tf.string_split([ln], delimiter='|')
    lbl = splits.values[0]
    img = splits.values[1]
    image_defaults = [[0.0] for col in range(IMAGE_PIXELS * IMAGE_PIXELS)]
    image = tf.stack(tf.decode_csv(img, record_defaults=image_defaults))
    norm = tf.constant(255, dtype=tf.float32, shape=(784,))
    normalized_image = tf.div(image, norm)
    label_value = tf.string_to_number(lbl, tf.int32)
    label = tf.one_hot(label_value, 10)
    return (normalized_image, label)

  def _parse_tfr(example_proto):
    feature_def = {"label": tf.FixedLenFeature(10, tf.int64),
                   "image": tf.FixedLenFeature(IMAGE_PIXELS * IMAGE_PIXELS, tf.int64)}
    features = tf.parse_single_example(example_proto, feature_def)
    norm = tf.constant(255, dtype=tf.float32, shape=(784,))
    image = tf.div(tf.to_float(features['image']), norm)
    label = tf.to_float(features['label'])
    return (image, label)

  def build_model(graph, x):
    with graph.as_default():
      # Variables of the hidden layer
      hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
                          stddev=1.0 / IMAGE_PIXELS), name="hid_w")
      hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
      tf.summary.histogram("hidden_weights", hid_w)

      # Variables of the softmax layer
      sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
                         stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
      sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
      tf.summary.histogram("softmax_weights", sm_w)

      hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
      hid = tf.nn.relu(hid_lin)

      y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
      prediction = tf.argmax(y, 1, name="prediction")
      return y, prediction

  if job_name == "ps":
    server.join()
  elif job_name == "worker":
    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
      worker_device="/job:worker/task:%d" % task_index,
      cluster=cluster)):

      # Dataset for input data
      image_dir = ctx.absolute_path(args.images_labels)
      file_pattern = os.path.join(image_dir, 'part-*')

      ds = tf.data.Dataset.list_files(file_pattern)
      ds = ds.shard(num_workers, task_index).repeat(args.epochs).shuffle(args.shuffle_size)
      if args.format == 'csv2':
        ds = ds.interleave(tf.data.TextLineDataset, cycle_length=args.readers, block_length=1)
        parse_fn = _parse_csv
      else:  # args.format == 'tfr'
        ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=args.readers, block_length=1)
        parse_fn = _parse_tfr
      ds = ds.map(parse_fn).batch(args.batch_size)
      iterator = ds.make_one_shot_iterator()
      x, y_ = iterator.get_next()

      # Build core model
      y, prediction = build_model(tf.get_default_graph(), x)

      # Add training bits
      x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
      tf.summary.image("x_img", x_img)

      global_step = tf.train.get_or_create_global_step()

      loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
      tf.summary.scalar("loss", loss)
      train_op = tf.train.AdagradOptimizer(0.01).minimize(
          loss, global_step=global_step)

      label = tf.argmax(y_, 1, name="label")
      correct_prediction = tf.equal(prediction, label)
      accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
      tf.summary.scalar("acc", accuracy)

      saver = tf.train.Saver()
      summary_op = tf.summary.merge_all()
      init_op = tf.global_variables_initializer()

    # Create a "supervisor", which oversees the training process and stores model state into HDFS
    model_dir = ctx.absolute_path(args.model)
    export_dir = ctx.absolute_path(args.export)
    print("tensorflow model path: {0}".format(model_dir))
    print("tensorflow export path: {0}".format(export_dir))
    summary_writer = tf.summary.FileWriter("tensorboard_%d" % worker_num, graph=tf.get_default_graph())

    if args.mode == 'inference':
      output_dir = ctx.absolute_path(args.output)
      print("output_dir: {}".format(output_dir))
      tf.gfile.MkDir(output_dir)
      output_file = tf.gfile.Open("{}/part-{:05d}".format(output_dir, task_index), mode='w')

    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(task_index == 0),
                                           scaffold=tf.train.Scaffold(init_op=init_op, summary_op=summary_op, saver=saver),
                                           checkpoint_dir=model_dir,
                                           hooks=[tf.train.StopAtStepHook(last_step=args.steps)]) as sess:
      print("{} session ready".format(datetime.now().isoformat()))

      # Loop until the session shuts down
      step = 0
      count = 0
      while not sess.should_stop():

        # Run a training step asynchronously.
        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
        # perform *synchronous* training.

        if args.mode == "train":
          if (step % 100 == 0):
            print("{} step: {} accuracy: {}".format(datetime.now().isoformat(), step, sess.run(accuracy)))
          _, summary, step = sess.run([train_op, summary_op, global_step])
          if task_index == 0:
            summary_writer.add_summary(summary, step)
        else:  # args.mode == "inference"
          labels, pred, acc = sess.run([label, prediction, accuracy])
          # print("label: {0}, pred: {1}".format(labels, pred))
          print("acc: {}".format(acc))
          for i in range(len(labels)):
            count += 1
            output_file.write("{} {}\n".format(labels[i], pred[i]))
          print("count: {}".format(count))

    if args.mode == 'inference':
      output_file.close()

    print("{} stopping MonitoredTrainingSession".format(datetime.now().isoformat()))

    # export model (on chief worker only)
    if args.mode == "train" and task_index == 0:
      tf.reset_default_graph()

      # add placeholders for input images (and optional labels)
      x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name='x')
      y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
      label = tf.argmax(y_, 1, name="label")

      # add core model
      y, prediction = build_model(tf.get_default_graph(), x)

      # restore from last checkpoint
      saver = tf.train.Saver()
      with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(model_dir)
        print("ckpt: {}".format(ckpt))
        assert ckpt, "Invalid model checkpoint path: {}".format(model_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)

        print("Exporting saved_model to: {}".format(export_dir))
        # exported signatures defined in code
        signatures = {
          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: {
            'inputs': { 'image': x },
            'outputs': { 'prediction': prediction },
            'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
          }
        }
        TFNode.export_saved_model(sess,
                                  export_dir,
                                  tf.saved_model.tag_constants.SERVING,
                                  signatures)
        print("Exported saved_model")

    # WORKAROUND for https://github.com/tensorflow/tensorflow/issues/21745
    # wait for all other nodes to complete (via done files)
    done_dir = "{}/{}/done".format(ctx.absolute_path(args.model), args.mode)
    print("Writing done file to: {}".format(done_dir))
    tf.gfile.MakeDirs(done_dir)
    with tf.gfile.GFile("{}/{}".format(done_dir, ctx.task_index), 'w') as done_file:
      done_file.write("done")

    for i in range(60):
      if len(tf.gfile.ListDirectory(done_dir)) < len(ctx.cluster_spec['worker']):
        print("{} Waiting for other nodes {}".format(datetime.now().isoformat(), i))
        time.sleep(1)
      else:
        print("{} All nodes done".format(datetime.now().isoformat()))
        break