Example #1
def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  if tf.gfile.Exists(FLAGS.eval_dir):
    tf.gfile.DeleteRecursively(FLAGS.eval_dir)
  tf.gfile.MakeDirs(FLAGS.eval_dir)
  with tf.device('/cpu:0'):
    evaluate()
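All of the main(argv=None) entry points in these examples follow the TF 1.x convention of being launched through tf.app.run(), which parses command-line flags into FLAGS and then calls main. A minimal sketch of the surrounding boilerplate the snippets assume (the eval_dir default here is illustrative):

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
# Illustrative flag; each example's module defines its own directories.
tf.app.flags.DEFINE_string('eval_dir', '/tmp/cifar10_eval',
                           """Directory where to write event logs.""")

if __name__ == '__main__':
  tf.app.run()  # parses flags, then calls main(argv)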
Example #2
def main(argv=None):  # pylint: disable=unused-argument
    cifar10.maybe_download_and_extract()
    if not gfile.Exists(FLAGS.train_dir):
        # gfile.DeleteRecursively(FLAGS.train_dir)
        gfile.MakeDirs(FLAGS.train_dir)
    train()
Example #3
def main(argv=None):  # pylint: disable=unused-argument
    cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    #Chen edit
    train()
Example #4
def main(argv=None):  # pylint: disable=unused-argument
    print('Entered main')
    cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    print('Before train')
    train()
Example #5
def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
#  if tf.gfile.Exists(FLAGS.eval_dir):
#    tf.gfile.DeleteRecursively(FLAGS.eval_dir)
#  tf.gfile.MakeDirs(FLAGS.eval_dir)
#  evaluate()

  if gfile.Exists(FLAGS.eval_dir):
    gfile.DeleteRecursively(FLAGS.eval_dir)
  gfile.MakeDirs(FLAGS.eval_dir)
  evaluate()
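Examples #2, #5, and #17 call gfile directly instead of tf.gfile, so they rely on an import along the following lines (the exact import used in the source files is an assumption):

# Assumed import that brings the bare gfile name into scope:
from tensorflow.python.platform import gfile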
Example #6
def test_read_cifar10():
    from tensorflow.models.image.cifar10 import cifar10
    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_string(
        'my_train_dir', '../cifar10_model/model1',
        """Directory where to write event logs """
        """and checkpoint.""")
    cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.my_train_dir):
        tf.gfile.DeleteRecursively(FLAGS.my_train_dir)
    tf.gfile.MakeDirs(FLAGS.my_train_dir)
    with tf.Session() as sess:
        images, labels = cifar10.distorted_inputs()
        sess.run(tf.initialize_all_variables())
        # distorted_inputs() is queue-based; without starting the queue
        # runners the sess.run call below would block forever.
        tf.train.start_queue_runners(sess=sess)
        a, b = sess.run([images, labels])
        print(len(a), len(a[0]))
Example #7
def main(argv=None):  # pylint: disable=unused-argument
    total_time = time.time()
    data_load_time = time.time()
    cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    data_load_time = time.time() - data_load_time
    train_time = time.time()
    train()
    train_time = time.time() - train_time
    test_time = time.time()
    cifar10_eval.evaluate()
    test_time = time.time() - test_time
    total_time = time.time() - total_time

    print_time('Data load', data_load_time)
    print_time('Train', train_time)
    print_time('Test', test_time)
    print_time('Total', total_time)
Example #8
def main(argv=None):  # pylint: disable=unused-argument
    total_time = time.time()
    data_load_time = time.time()
    cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    data_load_time = time.time() - data_load_time
    train_time = time.time()
    train()
    train_time = time.time() - train_time
    test_time = time.time()
    cifar10_eval.evaluate()
    test_time = time.time() - test_time
    total_time = time.time() - total_time

    print_time('Data load', data_load_time)
    print_time('Train', train_time)
    print_time('Test', test_time)
    print_time('Total', total_time)
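Examples #7 and #8 report their timings through a print_time helper that is not shown. A minimal hypothetical implementation consistent with the calls above:

def print_time(label, seconds):
    # Hypothetical helper; the original may format its output differently.
    print('%s time: %.3f s' % (label, seconds))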
Example #9
def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  # Recreate the base train_dir and the twenty numbered variants
  # (FLAGS.train_dir0 .. FLAGS.train_dir19) as empty directories.
  for train_dir in [FLAGS.train_dir] + [getattr(FLAGS, 'train_dir%d' % i)
                                        for i in range(20)]:
    if tf.gfile.Exists(train_dir):
      tf.gfile.DeleteRecursively(train_dir)
    tf.gfile.MakeDirs(train_dir)
  train()
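The getattr loop above assumes the twenty numbered flags were registered elsewhere in the module; an illustrative registration (paths are placeholders):

# Illustrative definition of the numbered directory flags:
for i in range(20):
  tf.app.flags.DEFINE_string('train_dir%d' % i,
                             '/tmp/cifar10_train%d' % i,
                             'Event-log directory for model %d.' % i)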
Example #10
if __name__ == '__main__':
    import shutil
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--train', help="Run training phase", action='store_true')
    parser.add_argument('-e', '--eval', help="Run evaluation phase", action='store_true')
    args = parser.parse_args()

    # -- training phase --
    if args.train:
        if os.path.exists(train_dir):
            shutil.rmtree(train_dir)
        os.makedirs(train_dir)
        # download the data - in case.
        cifar10.maybe_download_and_extract()
        with tf.Graph().as_default():
            global_step = tf.Variable(0, trainable=False)
            # Get images and labels for CIFAR-10.
            images, labels = distorted_inputs()
            # Build a Graph that computes the logits predictions from the
            # inference model.
            logits = inference(images)
            # Calculate cross entropy loss.
            total_loss = loss(logits, labels)
            # Build a Graph that trains the model with one batch of examples and
            # updates the model parameters.
            train_op = train(total_loss, global_step)
            # Create a saver.
            saver = tf.train.Saver(tf.all_variables())
            # Build the summary operation based on the TF collection of Summaries.
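The snippet breaks off after the summary comment. In the cifar10_train.py tutorial this example is adapted from, the block typically continues along these lines (a sketch using the same TF 0.x-era APIs, not this example's verbatim code):

            summary_op = tf.merge_all_summaries()
            init = tf.initialize_all_variables()
            sess = tf.Session()
            sess.run(init)
            # The input queues must be running before training starts.
            tf.train.start_queue_runners(sess=sess)
            summary_writer = tf.train.SummaryWriter(train_dir, sess.graph)
            for step in range(1000000):  # FLAGS.max_steps in the tutorial
                _, loss_value = sess.run([train_op, total_loss])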
Example #11
def main(argv=None):  # pylint: disable=unused-argument
    print("train directory is %s" % FLAGS.train_dir)
    cifar10.maybe_download_and_extract()
    train()
Example #12
def main(argv=None):
    print("CUDA_VISIBLE_DEVICES=\"%s\"" % (os.getenv('CUDA_VISIBLE_DEVICES')))
    cifar10.maybe_download_and_extract()
    train()
Example #13
def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  train()
Example #14
def main(unused_argv):
  #mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
  cifar10.maybe_download_and_extract()
  if FLAGS.download_only:
    sys.exit(0)
  #cifar10.maybe_download_and_extract()
  if FLAGS.job_name is None or FLAGS.job_name == "":
    raise ValueError("Must specify an explicit `job_name`")
  if FLAGS.task_index is None or FLAGS.task_index == "":
    raise ValueError("Must specify an explicit `task_index`")

  print("job name = %s" % FLAGS.job_name)
  print("task index = %d" % FLAGS.task_index)

  #Construct the cluster and start the server
  ps_spec = FLAGS.ps_hosts.split(",")
  worker_spec = FLAGS.worker_hosts.split(",")

  # Get the number of workers.
  num_workers = len(worker_spec)

  cluster = tf.train.ClusterSpec({
      "ps": ps_spec,
      "worker": worker_spec})

  if not FLAGS.existing_servers:
    # Not using existing servers. Create an in-process server.
    server = tf.train.Server(
        cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
    if FLAGS.job_name == "ps":
      server.join()

  is_chief = (FLAGS.task_index == 0)
  if FLAGS.num_gpus > 0:
    if FLAGS.num_gpus < num_workers:
      raise ValueError("number of gpus is less than number of workers")
    # Avoid gpu allocation conflict: now allocate task_num -> #gpu 
    # for each worker in the corresponding machine
    gpu = (FLAGS.task_index % FLAGS.num_gpus)
    worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
  elif FLAGS.num_gpus == 0:
    # Just allocate the CPU to worker server
    cpu = 0
    worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
  # The device setter will automatically place Variables ops on separate
  # parameter servers (ps). The non-Variable ops will be placed on the workers.
  # The ps use CPU and workers use corresponding GPU
  with tf.device(
      tf.train.replica_device_setter(
          worker_device=worker_device,
          ps_device="/job:ps/cpu:0",
          cluster=cluster)):
    cifar10.maybe_download_and_extract()
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # # Variables of the hidden layer
    # hid_w = tf.Variable(
    #     tf.truncated_normal(
    #         [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
    #         stddev=1.0 / IMAGE_PIXELS),
    #     name="hid_w")
    # hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")

    # # Variables of the softmax layer
    # sm_w = tf.Variable(
    #     tf.truncated_normal(
    #         [FLAGS.hidden_units, 10],
    #         stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
    #     name="sm_w")
    # sm_b = tf.Variable(tf.zeros([10]), name="sm_b")

    # # Ops: located on the worker specified with FLAGS.task_index
    # x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
    # y_ = tf.placeholder(tf.float32, [None, 10])

    # hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
    # hid = tf.nn.relu(hid_lin)

    # y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
    # cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    #train_op = cifar10.train(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables())

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()
    # Variables that affect learning rate.
    num_batches_per_epoch = 50000 / FLAGS.batch_size
    decay_steps = int(num_batches_per_epoch * 350)

    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(0.1,
                                    global_step,
                                    decay_steps,
                                    0.1,
                                    staircase=True)

    # Generate moving averages of all losses and associated summaries.
    #loss_averages_op = _add_loss_summaries(total_loss)

    opt = tf.train.GradientDescentOptimizer(lr)
    
    #opt = tf.train.AdamOptimizer(FLAGS.learning_rate)

    if FLAGS.sync_replicas:
      if FLAGS.replicas_to_aggregate is None:
        replicas_to_aggregate = num_workers
      else:
        replicas_to_aggregate = FLAGS.replicas_to_aggregate

      opt = tf.train.SyncReplicasOptimizerV2(
          opt,
          replicas_to_aggregate=replicas_to_aggregate,
          total_num_replicas=num_workers,
          name="cifar10_sync_replicas")

    train_step = opt.minimize(loss, global_step=global_step)

    if FLAGS.sync_replicas:
      local_init_op = opt.local_step_init_op
      if is_chief:
        local_init_op = opt.chief_init_op

      ready_for_local_init_op = opt.ready_for_local_init_op

      # Initial token and chief queue runners required by the sync_replicas mode
      chief_queue_runner = opt.get_chief_queue_runner()
      sync_init_op = opt.get_init_tokens_op()

    init_op = tf.global_variables_initializer()
    train_dir = tempfile.mkdtemp(dir="/mnt")

    if FLAGS.sync_replicas:
      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=train_dir,
          init_op=init_op,
          local_init_op=local_init_op,
          ready_for_local_init_op=ready_for_local_init_op,
          recovery_wait_secs=1,
          global_step=global_step)
    else:
      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=train_dir,
          init_op=init_op,
          recovery_wait_secs=1,
          global_step=global_step)

    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])

    # The chief worker (task_index==0) session will prepare the session,
    # while the remaining workers will wait for the preparation to complete.
    if is_chief:
      print("Worker %d: Initializing session..." % FLAGS.task_index)
    else:
      print("Worker %d: Waiting for session to be initialized..." %
            FLAGS.task_index)

    if FLAGS.existing_servers:
      server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
      print("Using existing server at: %s" % server_grpc_url)

      sess = sv.prepare_or_wait_for_session(server_grpc_url,
                                            config=sess_config)
    else:
      sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)

    print("Worker %d: Session initialization complete." % FLAGS.task_index)

    if FLAGS.sync_replicas and is_chief:
      # Chief worker will start the chief queue runner and call the init op.
      sess.run(sync_init_op)
      sv.start_queue_runners(sess, [chief_queue_runner])

    # Perform training
    time_begin = time.time()
    print("Training begins @ %f" % time_begin)

    local_step = 0
    while True:
      start_time = time.time()
      _, step = sess.run([train_step, global_step])
      duration = time.time() - start_time
      local_step += 1
      #assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      #if step % 10 == 0:
      #  num_examples_per_step = FLAGS.batch_size
      #  examples_per_sec = num_examples_per_step / duration
      #  sec_per_batch = float(duration)
      #  loss_value = 0
      #  format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
      #                'sec/batch)')
      #  print(format_str % (datetime.now(), local_step, loss_value,
      #                      examples_per_sec, sec_per_batch))
      now = time.time()
      print("%f: Worker %d: training step %d done (global step: %d)" % (now, FLAGS.task_index, local_step, step))

      if step >= FLAGS.train_steps:
        break

      #if step % 100 == 0:
      #  summary_str = sess.run(summary_op)
      #  summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      #if step % 1000 == 0 or (step + 1) == FLAGS.train_steps:
      #  checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
      #  saver.save(sess, checkpoint_path, global_step=step)
    # local_step = 0
    # while True:
    #   # Training feed
    #   batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
    #   train_feed = {x: batch_xs, y_: batch_ys}

    #   _, step = sess.run([train_step, global_step], feed_dict=train_feed)
    #   local_step += 1

    #   now = time.time()
    #   print("%f: Worker %d: training step %d done (global step: %d)" %
    #         (now, FLAGS.task_index, local_step, step))

    #   if step >= FLAGS.train_steps:
    #     break

    time_end = time.time()
    print("Training ends @ %f" % time_end)
    training_time = time_end - time_begin
    print("Training elapsed time: %f s" % training_time)
Example #15
def main(argv=None):  # pylint: disable=unused-argument
    cifar10.maybe_download_and_extract()
    # Recreate the base train_dir and the twenty numbered variants
    # (FLAGS.train_dir0 .. FLAGS.train_dir19) as empty directories.
    for train_dir in [FLAGS.train_dir] + [getattr(FLAGS, 'train_dir%d' % i)
                                          for i in range(20)]:
        if tf.gfile.Exists(train_dir):
            tf.gfile.DeleteRecursively(train_dir)
        tf.gfile.MakeDirs(train_dir)
    train()
Example #16
def main(argv=None):
    cifar10.maybe_download_and_extract()
    train()
Example #17
def main(argv=None):  # pylint: disable=unused-argument
    cifar10.maybe_download_and_extract()
    if gfile.Exists(FLAGS.eval_dir):
        gfile.DeleteRecursively(FLAGS.eval_dir)
    gfile.MakeDirs(FLAGS.eval_dir)
    evaluate()
Example #18
def main(argv=None):
    cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.eval_dir):
        tf.gfile.DeleteRecursively(FLAGS.eval_dir)
    tf.gfile.MakeDirs(FLAGS.eval_dir)
    evaluate()
Example #19
def main(unused_argv):
    cifar10.maybe_download_and_extract()
    if FLAGS.download_only:
        sys.exit(0)
    if FLAGS.job_name is None or FLAGS.job_name == "":
        raise ValueError("Must specify an explicit `job_name`")
    if FLAGS.task_index is None or FLAGS.task_index == "":
        raise ValueError("Must specify an explicit `task_index`")

    print("job name = %s" % FLAGS.job_name)
    print("task index = %d" % FLAGS.task_index)

    #Construct the cluster and start the server
    ps_spec = FLAGS.ps_hosts.split(",")
    worker_spec = FLAGS.worker_hosts.split(",")

    #Approximation Layers
    approx_layers = FLAGS.layers_to_train.split(",")
    len_approx_layers = len(approx_layers)

    # Get the number of workers.
    num_workers = len(worker_spec)
    num_ps = len(ps_spec)

    cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec})

    if not FLAGS.existing_servers:
        # Not using existing servers. Create an in-process server.
        server = tf.train.Server(cluster,
                                 job_name=FLAGS.job_name,
                                 task_index=FLAGS.task_index)
        if FLAGS.job_name == "ps":
            server.join()

    is_chief = (FLAGS.task_index == 0)
    if FLAGS.num_gpus > 0:
        if FLAGS.num_gpus < num_workers:
            raise ValueError("number of gpus is less than number of workers")
        # Avoid gpu allocation conflict: now allocate task_num -> #gpu
        # for each worker in the corresponding machine
        gpu = (FLAGS.task_index % FLAGS.num_gpus)
        worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
    elif FLAGS.num_gpus == 0:
        # Just allocate the CPU to worker server
        cpu = 0
        worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
    # The device setter will automatically place Variables ops on separate
    # parameter servers (ps). The non-Variable ops will be placed on the workers.
    # The ps use CPU and workers use corresponding GPU
    with tf.device(
            tf.train.replica_device_setter(
                worker_device=worker_device,
                ps_device="/job:ps/cpu:0",
                cluster=cluster,
                ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
                    num_ps, tf.contrib.training.byte_size_load_fn))):
        global_step = tf.Variable(0, name="global_step", trainable=False)
        #variables_to_update = tf.Placeholder(, name="variables_to_update")

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        #train_op = cifar10.train(loss, global_step)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.merge_all_summaries()
        # Variables that affect learning rate.
        num_batches_per_epoch = 50000 / FLAGS.batch_size
        decay_steps = int(num_batches_per_epoch * 350)

        # Decay the learning rate exponentially based on the number of steps.
        lr = tf.train.exponential_decay(0.1,
                                        global_step,
                                        decay_steps,
                                        0.1,
                                        staircase=True)

        opt = tf.train.GradientDescentOptimizer(lr)

        if FLAGS.sync_replicas:
            if FLAGS.replicas_to_aggregate is None:
                replicas_to_aggregate = num_workers
            else:
                replicas_to_aggregate = FLAGS.replicas_to_aggregate

            opt = tf.train.SyncReplicasOptimizerV2(
                opt,
                replicas_to_aggregate=replicas_to_aggregate,
                total_num_replicas=num_workers,
                name="cifar10_sync_replicas")

        #trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        train_step = opt.minimize(loss, global_step=global_step)

        # Approximation Training
        var_list = []
        for i in range(len_approx_layers):
            var_list = var_list + tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope=approx_layers[i])

        train_step_approx = opt.minimize(loss,
                                         global_step=global_step,
                                         var_list=var_list)

        if FLAGS.sync_replicas:
            local_init_op = opt.local_step_init_op
            if is_chief:
                local_init_op = opt.chief_init_op

            ready_for_local_init_op = opt.ready_for_local_init_op

            # Initial token and chief queue runners required by the sync_replicas mode
            chief_queue_runner = opt.get_chief_queue_runner()
            sync_init_op = opt.get_init_tokens_op()

        init_op = tf.global_variables_initializer()
        train_dir = tempfile.mkdtemp(dir="/mnt",
                                     suffix="data",
                                     prefix="cifar10_train")

        if FLAGS.sync_replicas:
            sv = tf.train.Supervisor(
                is_chief=is_chief,
                logdir=train_dir,
                init_op=init_op,
                local_init_op=local_init_op,
                saver=None,
                summary_op=summary_op,
                save_summaries_secs=120,
                save_model_secs=600,
                checkpoint_basename='model.ckpt',
                ready_for_local_init_op=ready_for_local_init_op,
                recovery_wait_secs=1,
                global_step=global_step)
        else:
            sv = tf.train.Supervisor(is_chief=is_chief,
                                     logdir=train_dir,
                                     init_op=init_op,
                                     saver=None,
                                     summary_op=summary_op,
                                     save_summaries_secs=120,
                                     save_model_secs=600,
                                     checkpoint_basename='model.ckpt',
                                     recovery_wait_secs=1,
                                     global_step=global_step)

        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     log_device_placement=False,
                                     device_filters=[
                                         "/job:ps",
                                         "/job:worker/task:%d" %
                                         FLAGS.task_index
                                     ])

        # The chief worker (task_index==0) session will prepare the session,
        # while the remaining workers will wait for the preparation to complete.
        if is_chief:
            print("Worker %d: Initializing session..." % FLAGS.task_index)
        else:
            print("Worker %d: Waiting for session to be initialized..." %
                  FLAGS.task_index)

        if FLAGS.existing_servers:
            server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
            print("Using existing server at: %s" % server_grpc_url)

            sess = sv.prepare_or_wait_for_session(server_grpc_url,
                                                  config=sess_config)
        else:
            sess = sv.prepare_or_wait_for_session(server.target,
                                                  config=sess_config)

        print("Worker %d: Session initialization complete." % FLAGS.task_index)

        if FLAGS.sync_replicas and is_chief:
            # Chief worker will start the chief queue runner and call the init op.
            sess.run(sync_init_op)
            sv.start_queue_runners(sess, [chief_queue_runner])

        # Restore from Checkpoint
        if FLAGS.checkpoint_restore > 0:
            checkpoint_directory = FLAGS.checkpoint_dir + str(
                FLAGS.checkpoint_restore)
            ckpt = tf.train.get_checkpoint_state(checkpoint_directory)
            if ckpt and ckpt.model_checkpoint_path:
                # Restores from checkpoint
                saver.restore(sess, ckpt.model_checkpoint_path)
                # Assuming model_checkpoint_path looks something like:
                #   /my-favorite-path/cifar10_train/model.ckpt-0,
                # extract global_step from it.
                #global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            else:
                print('No checkpoint file found')
                return

        # Perform training
        time_begin = time.time()
        print("Training begins @ %f" % time_begin)

        local_step = 0
        num_examples_per_step = 128
        f = open('/mnt/train_output.log', 'w')
        #f.write("Training begins @ " + str(time_begin) +"\n")
        f.write(
            "Duration\tWorker\tLocalStep\tGlobalStep\tLoss\tExamplesPerSec\n")
        f.close()
        last = time_begin
        while True:
            start_time = time.time()
            if local_step < FLAGS.approx_step:
                _, step, loss_value = sess.run([train_step, global_step, loss])
            else:
                if local_step % FLAGS.approx_interval == 0:
                    _, step, loss_value = sess.run(
                        [train_step_approx, global_step, loss])
                else:
                    _, step, loss_value = sess.run(
                        [train_step, global_step, loss])

            duration = time.time() - start_time
            local_step += 1
            if local_step % 10 == 0:
                now = time.time()
                examples_per_sec = 10 * num_examples_per_step / (now - last)
                print(
                    "%f: Worker %d: step %d (global step: %d of %d) loss = %.2f examples_per_sec = %.2f \n"
                    % (now - last, FLAGS.task_index, local_step, step,
                       FLAGS.train_steps, loss_value, examples_per_sec))
                f = open('/mnt/train_output.log', 'a')
                f.write(
                    str(now - last) + "\t" + str(FLAGS.task_index) + "\t" +
                    str(local_step) + "\t" + str(step) + "\t" +
                    str(loss_value) + "\t" + str(examples_per_sec) + "\n")
                f.close()
                last = now

            if step >= FLAGS.train_steps:
                break

            if sv.should_stop():
                print('Stopped due to abort')
                break
            # Save the model checkpoint periodically.
            #if is_chief and (step % 1000 == 0 or (step + 1) == FLAGS.train_steps):
            if (step % 1000 == 0 or (step + 1) == FLAGS.train_steps):
                print('Taking a Checkpoint @ Global Step ' + str(step))
                checkpoint_dir = "/mnt/checkpoint" + str(step)
                if tf.gfile.Exists(checkpoint_dir):
                    tf.gfile.DeleteRecursively(checkpoint_dir)
                tf.gfile.MakeDirs(checkpoint_dir)
                checkpoint_path = os.path.join(checkpoint_dir, "model.ckpt")
                saver.save(sess, checkpoint_path, global_step=step)

        time_end = time.time()
        print("Training ends @ %f" % time_end)
        f = open('/mnt/train_output.log', 'a')
        #f.write("Training ends @ " + str(time_end) +"\n")
        training_time = time_end - time_begin
        print("Training elapsed time: %f s" % training_time)
        f.write("Training elapsed time: " + str(training_time) + " s\n")
        f.close()
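The var_list passed to train_step_approx works because cifar10.inference builds its layers inside named variable scopes, so the names given in --layers_to_train select them directly. An illustrative restriction to two of the tutorial's conventional scope names:

# Illustrative: collect only the conv2 and local3 trainable variables.
approx_vars = []
for scope in ['conv2', 'local3']:
    approx_vars += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope=scope)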
Example #20
def main(argv=None):  # pylint: disable=unused-argument
    logging.basicConfig(level=logging.INFO)

    cifar10.maybe_download_and_extract()
    visualize_excitations()
Example #21
def main(argv=None):
    cifar10.maybe_download_and_extract()
    train()