def evaluate():
    """Evaluate the CIFAR-100 model on single inputs, repeatedly or once.

    Builds a fresh graph containing the input pipeline returned by
    ``cifar100.single_inputs``, the inference model and a top-1 correctness
    op. It restores the exponential-moving-average (shadow) values of the
    trained variables and hands everything to ``eval_once`` (which also
    receives the logits tensor). Unless ``FLAGS.run_once`` is set, the
    evaluation pass repeats every ``FLAGS.eval_interval_secs`` seconds.
    """
    with tf.Graph().as_default() as graph:
        # Read the test split when FLAGS.eval_data == 'test'.
        use_test_split = FLAGS.eval_data == 'test'
        images, labels = cifar100.single_inputs(eval_data=use_test_split)

        # Graph construction only; nothing is computed here.
        logits = cifar100.inference(images)

        # Per example: is the true label the single highest-scoring class?
        correct_op = tf.nn.in_top_k(logits, labels, 1)

        # Map every variable to its moving-average shadow so that restoring
        # loads the averaged weights straight into the ordinary variables.
        ema = tf.train.ExponentialMovingAverage(cifar100.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(ema.variables_to_restore())

        # Merge every summary registered in the graph's summary collection.
        merged_summaries = tf.summary.merge_all()
        writer = tf.summary.FileWriter(FLAGS.eval_dir, graph)

        keep_running = True
        while keep_running:
            eval_once(saver, writer, correct_op, merged_summaries, logits)
            keep_running = not FLAGS.run_once
            if keep_running:
                time.sleep(FLAGS.eval_interval_secs)
def evaluate_one():
    """Restore the trained model and print the logits for one input batch.

    Builds the input pipeline and the inference graph inside ONE fresh graph.
    It restores the moving-average (shadow) values of the trained variables
    from the latest checkpoint in ``FLAGS.checkpoint_dir`` and prints the
    logits computed for one batch from ``cifar100.inputs``. If no checkpoint
    is found, it prints a message and returns early.
    """
    with tf.Graph().as_default():
        # Use the test split when FLAGS.eval_data == 'test'.
        eval_data = FLAGS.eval_data == 'test'
        images, _ = cifar100.inputs(eval_data=eval_data)

        # Build the inference graph in the SAME graph as `images`. The
        # previous version built it outside the `with` block, which mixes
        # tensors from two different graphs.
        logits = cifar100.inference(images)

        # Restore the moving-average versions of the learned variables.
        variable_averages = tf.train.ExponentialMovingAverage(
            cifar100.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        # Load the saved shadow values directly into the current variables.
        saver = tf.train.Saver(variables_to_restore)

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                # No checkpoint could be loaded.
                print('No checkpoint file found')
                return
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("load successfully")

            # Start any input-pipeline queue runners. Without them a
            # queue-based `cifar100.inputs` would block sess.run forever. If
            # the graph has none, this is a no-op.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                final_label = sess.run(logits)
            finally:
                coord.request_stop()
                coord.join(threads)
            print(final_label)
def evaluate():
    """Evaluate the CIFAR-100 model on batched inputs, repeatedly or once.

    Creates a new graph with the ``cifar100.inputs`` pipeline, the inference
    model and a top-1 correctness op. It restores the moving-average (shadow)
    values of the trained variables and delegates each evaluation pass to
    ``eval_once``. Passes repeat every ``FLAGS.eval_interval_secs`` seconds
    unless ``FLAGS.run_once`` is set.
    """
    with tf.Graph().as_default() as graph:
        images, labels = cifar100.inputs(eval_data=(FLAGS.eval_data == 'test'))
        logits = cifar100.inference(images)
        # True for each example whose label is the single top prediction.
        top1_correct = tf.nn.in_top_k(logits, labels, 1)

        # Restoring through this map loads the averaged weights into the
        # plain variables.
        averager = tf.train.ExponentialMovingAverage(
            cifar100.MOVING_AVERAGE_DECAY)
        restore_map = averager.variables_to_restore()
        saver = tf.train.Saver(restore_map)

        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, graph)

        while True:
            eval_once(saver, summary_writer, top1_correct, summary_op)
            if FLAGS.run_once:
                return
            time.sleep(FLAGS.eval_interval_secs)
def tower_loss(scope, images, labels):
    """Return the total loss of one CIFAR-100 tower and log its components.

    Args:
      scope: unique prefix string identifying the tower, e.g. 'tower_0'.
        Only entries of the 'losses' collection under this scope are summed.
      images: image batch passed to ``cifar100.inference`` (presumably a 4-D
        [batch_size, height, width, 3] tensor).
      labels: label batch passed to ``cifar100.loss`` (presumably 1-D
        [batch_size]).

    Returns:
      The scalar ``total_loss`` tensor: ``tf.add_n`` over every tensor in the
      'losses' collection under ``scope``.
    """
    # The return value of cifar100.loss is discarded. It is presumably called
    # so that it registers its terms in the 'losses' collection read below.
    cifar100.loss(cifar100.inference(images), labels)

    tower_losses = tf.get_collection('losses', scope)
    total_loss = tf.add_n(tower_losses, name='total_loss')

    # Strip the '<TOWER_NAME>_N/' prefix from op names so summaries from
    # different towers share one name on TensorBoard.
    tower_prefix = '%s_[0-9]*/' % cifar100.TOWER_NAME
    for loss_tensor in tower_losses + [total_loss]:
        summary_name = re.sub(tower_prefix, '', loss_tensor.op.name)
        tf.summary.scalar(summary_name, loss_tensor)

    return total_loss
# Example #5
# 0
def evaluate():
    """Evaluate the CIFAR-100 model using top-3 correctness, repeatedly or once.

    Builds ``cifar100.eval_inputs`` and the inference model in a new graph,
    then calls ``eval_once`` with a Saver over all global variables and an op
    that marks whether each label is among the 3 highest logits. Passes
    repeat every ``FLAGS.eval_interval_secs`` seconds unless
    ``FLAGS.run_once`` is set.
    """
    with tf.Graph().as_default():
        # Never read here, but creating it adds a variable that the Saver
        # below also covers. Presumably this mirrors the step variable saved
        # by the matching training script; confirm before removing.
        _unused_step = tf.Variable(0, trainable=False)
        images, labels = cifar100.eval_inputs()
        logits = cifar100.inference(images)
        top3_op = tf.nn.in_top_k(logits, labels, 3)
        saver = tf.train.Saver(tf.global_variables())
        while True:
            eval_once(saver, top3_op)
            if FLAGS.run_once:
                return
            time.sleep(FLAGS.eval_interval_secs)
# Example #6
# 0
def train():
    """Train CIFAR-100 for ``FLAGS.max_steps`` steps with a plain session loop.

    Builds the distorted-input pipeline, model, loss and training op in a new
    graph and runs the training op once per step. Throughput is printed every
    10 steps. A checkpoint is written to ``FLAGS.train_dir`` every 1000 steps
    and after the final step.

    Raises:
        AssertionError: if the loss becomes NaN.
    """
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        images, labels = cifar100.distorted_inputs()

        logits = cifar100.inference(images)

        loss = cifar100.loss(logits, labels)

        train_op = cifar100.train(loss, global_step)

        # tf.global_variables / global_variables_initializer are the
        # non-deprecated equivalents of all_variables /
        # initialize_all_variables.
        saver = tf.train.Saver(tf.global_variables())

        init = tf.global_variables_initializer()

        config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)
        # The context manager guarantees the session is closed.
        with tf.Session(config=config) as sess:
            sess.run(init)

            # Run the input queue runners under a Coordinator so their
            # threads are stopped and joined before the session closes.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for step in range(FLAGS.max_steps):
                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss])
                    duration = time.time() - start_time

                    assert not np.isnan(loss_value), \
                        'Model diverged with loss = NaN'

                    if step % 10 == 0:
                        num_examples_per_step = FLAGS.batch_size
                        examples_per_sec = num_examples_per_step / duration
                        sec_per_batch = float(duration)

                        format_str = (
                            '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                            'sec/batch)')
                        print(format_str % (datetime.now(), step, loss_value,
                                            examples_per_sec, sec_per_batch))

                    # Save the model checkpoint periodically and after the
                    # last step.
                    if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                        checkpoint_path = os.path.join(FLAGS.train_dir,
                                                       'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
            finally:
                coord.request_stop()
                coord.join(threads)
# Example #7
# 0
def train():
    """Train CIFAR-100 with a MonitoredTrainingSession and report timing.

    Builds the input pipeline (pinned to the CPU), model, loss and training
    op in a new graph. It then runs the training op until ``StopAtStepHook``
    reaches ``FLAGS.max_steps``; ``NanTensorHook`` guards against a NaN loss.
    Loss and throughput are printed every ``FLAGS.log_frequency`` steps.
    Afterwards, the wall-clock training time and the loss of the final step
    are printed.
    """
    with tf.Graph().as_default():
        # Same helper as the other training functions
        # (tf.contrib.framework's version is deprecated).
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-100.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels = cifar100.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar100.inference(images)

        # Calculate loss.
        loss = cifar100.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar100.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime, and remembers the most recent loss."""

            # None until a training step has completed, so last_loss() is
            # always safe to call.
            _last_loss = None

            def begin(self):
                self._step = -1
                self._start_time = time.time()
                self._last_loss = None

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                loss_value = run_values.results
                # Record every step (not only logged ones) so last_loss()
                # really is the final loss.
                self._last_loss = loss_value
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

            def last_loss(self):
                """Return the loss of the most recent step, or None."""
                return self._last_loss

        loghook = _LoggerHook()
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss), loghook
                ],
                config=tf.ConfigProto(log_device_placement=FLAGS.
                                      log_device_placement)) as mon_sess:

            t1 = time.time()
            while not mon_sess.should_stop():
                mon_sess.run(train_op)

            t2 = time.time()
            print('spent %f seconds to train %d step' %
                  (t2 - t1, FLAGS.max_steps))
            last_loss = loghook.last_loss()
            # No step may have run (e.g. a restored global step is already
            # at max_steps).
            if last_loss is not None:
                print('last loss value: %.2f ' % last_loss)
def train():
    """Train CIFAR-100 under a MonitoredTrainingSession with GPU memory limits.

    Builds the input pipeline (on the CPU), model, loss and training op in a
    new graph, then runs the training op until ``FLAGS.max_steps`` is
    reached. Loss and throughput are printed every ``FLAGS.log_frequency``
    steps. Checkpoints are restored from / saved to ``FLAGS.train_dir``, with
    a save every 10000 steps.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        with tf.device('/cpu:0'):
            # NOTE(review): 'destorted_inputs' looks like a misspelling of the
            # 'distorted_inputs' used by the other training functions;
            # confirm which name the cifar100 module actually defines.
            images, labels = cifar100.destorted_inputs()

        # Build the model; its logits are compared against the labels by the
        # (cross-entropy) loss.
        loss = cifar100.loss(cifar100.inference(images), labels)
        train_op = cifar100.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Prints loss and throughput every FLAGS.log_frequency steps."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Ask the session to also fetch the loss on every run.
                return tf.train.SessionRunArgs(loss)

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency != 0:
                    return
                now = time.time()
                elapsed = now - self._start_time
                self._start_time = now

                per_sec = FLAGS.log_frequency * FLAGS.batch_size / elapsed
                per_batch = float(elapsed / FLAGS.log_frequency)
                fmt = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                       '%.3f sec/batch)')
                print(fmt % (datetime.now(), self._step, run_values.results,
                             per_sec, per_batch))

        config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
        # Use TensorFlow's BFC GPU allocator.
        config.gpu_options.allocator_type = 'BFC'
        # Cap this process at 50% of the GPU memory.
        config.gpu_options.per_process_gpu_memory_fraction = 0.5
        # Allocate GPU memory on demand instead of all up front.
        config.gpu_options.allow_growth = True

        hooks = [
            # Stop once the global step reaches FLAGS.max_steps.
            tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
            # Fail training if the loss becomes NaN.
            tf.train.NanTensorHook(loss),
            _LoggerHook(),
        ]
        # MonitoredTrainingSession handles session creation, checkpoint
        # restore/save and hooks; it can also be used for distributed training.
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                save_checkpoint_secs=None,
                save_checkpoint_steps=10000,
                hooks=hooks,
                config=config) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
# Example #9
# 0
def train():
    """Distributed CIFAR-100 training with periodically averaged global weights.

    Every process builds a ClusterSpec from ``FLAGS.ps_hosts`` /
    ``FLAGS.worker_hosts`` and starts a server. Parameter-server tasks block
    forever in ``server.join()``; only workers continue.

    Each worker trains its own local copy of the model (placed on its worker
    device) on CIFAR-100 fine labels loaded with
    ``dset.cifar100.load_data``. Every ``FLAGS.tau`` steps it moves the
    global copy (variables on ``/job:ps/task:0``) towards its local weights
    by ``alpha * (local - global)``. After one training step it subtracts the
    same delta from the local weights. This resembles elastic-averaging SGD;
    confirm the intended algorithm.

    Every ``FLAGS.log_frequency`` steps, the global weights are evaluated
    with ``predt`` without disturbing the local weights. The same evaluation
    runs once more after training.

    Relies on names defined outside this function: ``alpha``,
    ``EPOCH_SIZE``, ``normalize``, ``predt`` and ``dset`` (presumably
    module-level; not visible here).
    """

    #lanhin
    #Construct the cluster and start the server
    ps_spec = FLAGS.ps_hosts.split(",")
    worker_spec = FLAGS.worker_hosts.split(",")

    # Get the number of workers.
    num_workers = len(worker_spec)

    cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec})

    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)
    if FLAGS.job_name == "ps":
        # Parameter servers only host variables; join() never returns.
        server.join()
    # only worker will do train()
    # Worker task 0 acts as the chief.
    is_chief = False
    if FLAGS.task_index == 0:
        is_chief = True

    #lanhin end

    #with tf.Graph().as_default():

    # Use comment to choose which way of tf.device() you want to use
    #with tf.Graph().as_default(), tf.device(tf.train.replica_device_setter(
    #    worker_device="/job:worker/task:%d" % FLAGS.task_index,
    #    cluster=cluster)):
    # Everything below defaults to this worker's own device (local copy of
    # the model); the global copy is placed explicitly on the ps further down.
    with tf.device("job:worker/task:%d" % FLAGS.task_index):
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        #with tf.device('/cpu:0'):
        #images, labels = cifar10.distorted_inputs()

        # Load CIFAR-100 (fine labels) into memory; batches are fed through
        # the placeholders below rather than through an input queue.
        (x_train, y_train_orl), (x_test, y_test_orl) = dset.cifar100.load_data(
            label_mode='fine')
        x_train = x_train.astype('float32')
        x_test = x_test.astype('float32')
        x_train, x_test = normalize(x_train, x_test)

        # Labels are cast to int32 and flattened to 1-D for the placeholder.
        y_train_orl = y_train_orl.astype('int32')
        y_test_orl = y_test_orl.astype('int32')
        y_train_flt = y_train_orl.ravel()
        y_test_flt = y_test_orl.ravel()

        x = tf.placeholder(tf.float32, shape=(FLAGS.batch_size, 32, 32, 3))
        y = tf.placeholder(tf.int32, shape=(FLAGS.batch_size, ))

        # Build a Graph that computes the logits predictions from the
        # inference model.
        # In this project cifar100.inference also returns the list of model
        # variables that are mirrored into the tmp and global copies below.
        #logits, local_var_list = cifar10.inference(images)
        logits, local_var_list = cifar100.inference(x)

        # Calculate loss.
        #loss = cifar10.loss(logits, labels)
        loss = cifar100.loss(logits, y)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar100.train(loss, global_step)

        # the temp var part, for performance testing
        # One zero-initialised tmp variable per local variable; used to stash
        # the local weights while the global weights are being evaluated.
        tmp_var_list = []
        var_index = 0
        for var in local_var_list:
            var_index += 1
            tmp_var_list.append(
                tf.Variable(tf.zeros(var.shape),
                            name="tmp_var" + str(var_index)))

        # the non chief workers get local var init_op here
        # (created before the ps-side variables below exist, so it does not
        # cover them)
        if not is_chief:
            init_op = tf.global_variables_initializer()
        else:
            init_op = None

        # start global variables region
        global_var_list = []
        with tf.device("/job:ps/replica:0/task:0/cpu:0"):
            # barrier var
            # Shared counter: each worker adds 1 once its session is ready.
            finished = tf.get_variable("worker_finished", [],
                                       tf.int32,
                                       tf.zeros_initializer(tf.int32),
                                       trainable=False)
            with finished.graph.colocate_with(finished):
                finish_op = finished.assign_add(1, use_locking=True)

            # Global copy of every model variable, on the parameter server.
            var_index = 0
            for var in local_var_list:
                var_index += 1
                global_var_list.append(
                    tf.Variable(tf.zeros(var.shape),
                                name="glo_var" + str(var_index)))

        def assign_global_vars():  # assign local vars' values to global vars
            return [
                gvar.assign(lvar)
                for (gvar, lvar) in zip(global_var_list, local_var_list)
            ]

        def assign_local_vars():  # assign global vars' values to local vars
            return [
                lvar.assign(gvar)
                for (gvar, lvar) in zip(global_var_list, local_var_list)
            ]

        def assign_tmp_vars():  # assign local vars' values to tmp vars
            return [
                tvar.assign(lvar)
                for (tvar, lvar) in zip(tmp_var_list, local_var_list)
            ]

        def assign_local_vars_from_tmp(
        ):  # assign tmp vars' values to local vars
            return [
                lvar.assign(tvar)
                for (tvar, lvar) in zip(tmp_var_list, local_var_list)
            ]

        def update_before_train(alpha, w, global_w):
            # delta = alpha * (local - global); global += delta.
            # Returns (assign op, delta tensor).
            varib = alpha * (w - global_w)
            gvar_op = global_w.assign(global_w + varib)
            return gvar_op, varib

        def update_after_train(w, vab):
            # local -= delta (delta is fed back in through a placeholder).
            return w.assign(w - vab)

        assign_list_local = assign_local_vars()
        assign_list_global = assign_global_vars()
        assign_list_loc2tmp = assign_tmp_vars()
        assign_list_tmp2loc = assign_local_vars_from_tmp()

        # Pre-build the per-variable update ops used every FLAGS.tau steps.
        # NOTE(review): `alpha` is not defined in this function; presumably a
        # module-level constant — confirm.
        before_op_tuple_list = []
        after_op_tuple_list = []
        vbholder_list = []
        for (gvar, lvar) in zip(global_var_list, local_var_list):
            before_op_tuple_list.append(
                (update_before_train(alpha, lvar, gvar)))
        for var in local_var_list:
            vbholder_list.append(tf.placeholder("float", var.shape))
            after_op_tuple_list.append(
                (update_after_train(var,
                                    vbholder_list[-1]), vbholder_list[-1]))

        # the chief worker get global var init op here
        # (created after the ps-side variables, so it covers them too)
        if is_chief:
            init_op = tf.global_variables_initializer()

        # global variables region end

        #lanhin start
        # NOTE(review): is_chief is hard-coded to True for every worker (the
        # real flag is commented out), so each worker runs its own init_op.
        sv = tf.train.Supervisor(
            is_chief=True,  #is_chief,
            logdir=FLAGS.train_dir,
            init_op=init_op,
            #local_init_op=loc_init_op,
            recovery_wait_secs=1)
        #global_step=global_step)

        # Only the ps job and this worker's devices are visible to the session.
        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement,
            device_filters=[
                "/job:ps", "/job:worker/task:%d" % FLAGS.task_index
            ])

        # The chief worker (task_index==0) session will prepare the session,
        # while the remaining workers will wait for the preparation to complete.
        if is_chief:
            print("Worker %d: Initializing session..." % FLAGS.task_index)
        else:
            print("Worker %d: Waiting for session to be initialized..." %
                  FLAGS.task_index)

        sess = sv.prepare_or_wait_for_session(server.target,
                                              config=sess_config)

        # Barrier: the chief first seeds the global weights from its local
        # ones. Every worker then increments `finished` and polls it once a
        # second until all num_workers have arrived. Finally, each worker
        # copies the global weights into its local variables.
        if is_chief:
            sess.run(assign_list_global)
            barrier_finished = sess.run(finish_op)
            print("barrier_finished:", barrier_finished)
        else:
            barrier_finished = sess.run(finish_op)
            print("barrier_finished:", barrier_finished)
        while barrier_finished < num_workers:
            time.sleep(1)
            barrier_finished = sess.run(finished)
        sess.run(assign_list_local)
        print("Worker %d: Session initialization complete." % FLAGS.task_index)
        # lanhin end

        #sess = tf.Session()
        #sess.run(init_op)
        #tf.train.start_queue_runners(sess)
        # NOTE(review): the file is opened and closed but nothing is written
        # (the timeline write below is commented out), so it is only
        # created/truncated. run_options / run_metadata are likewise unused.
        f = open('tl_dist.json', 'w')
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()

        time_begin = time.time()
        #    while not mon_sess.should_stop():
        #      mon_sess.run(train_op)
        for step in range(FLAGS.max_steps):
            # Sequential mini-batches from the in-memory training set.
            # EPOCH_SIZE is defined outside this function.
            offset = (step * FLAGS.batch_size) % (EPOCH_SIZE -
                                                  FLAGS.batch_size)
            x_data = x_train[offset:(offset + FLAGS.batch_size), ...]
            y_data_flt = y_train_flt[offset:(offset + FLAGS.batch_size)]

            if step % FLAGS.log_frequency == 0:
                # Evaluate the GLOBAL weights: stash local -> tmp, load
                # global -> local, evaluate, then restore tmp -> local.
                time_step = time.time()
                steps_time = time_step - time_begin
                print("step:", step, " steps time:", steps_time, end='  ')
                sess.run(assign_list_loc2tmp)
                sess.run(assign_list_local)
                predt(sess, x_test, y_test_flt, logits, x, y)
                sess.run(assign_list_tmp2loc)
                # Reset so the next "steps time" excludes evaluation time.
                time_begin = time.time()
            if step % FLAGS.tau == 0 and step > 0:  # update global weights
                # 1) global += delta, fetching each delta value.
                thevarib_list = []
                for i in range(0, len(before_op_tuple_list)):
                    (gvar_op, varib) = before_op_tuple_list[i]
                    _, thevarib = sess.run([gvar_op, varib])
                    thevarib_list.append(thevarib)

                # 2) normal training step.
                sess.run(train_op, feed_dict={x: x_data, y: y_data_flt})

                # 3) local -= the same delta.
                for i in range(0, len(after_op_tuple_list)):
                    (lvar_op, thevaribHolder) = after_op_tuple_list[i]
                    sess.run(lvar_op,
                             feed_dict={thevaribHolder: thevarib_list[i]})

            else:
                sess.run(train_op, feed_dict={
                    x: x_data,
                    y: y_data_flt
                })  #, options=run_options, run_metadata=run_metadata)
                #tl = timeline.Timeline(run_metadata.step_stats)
                #ctf = tl.generate_chrome_trace_format()
        #f.write(ctf)
        time_end = time.time()
        # NOTE(review): time_begin is reset after every logging evaluation,
        # so this measures only the time since the last log step, not the
        # whole training run; confirm which is intended.
        training_time = time_end - time_begin
        print("Training elapsed time: %f s" % training_time)
        f.close()
        # Final evaluation with the global weights.
        sess.run(assign_list_local)
        predt(sess, x_test, y_test_flt, logits, x, y)
# Example #10
# 0
def train():
    """Train the two-headed CIFAR-100 model, alternating between its heads.

    In this project ``cifar100.inference`` returns two logits tensors (A and
    B). Each gets its own loss and training op, and the loop runs op A, then
    op B, until the MonitoredTrainingSession requests a stop.

    The logging hook fetches loss A on every run, so its counter advances on
    both the A and the B runs. Every ``FLAGS.log_frequency`` counts it prints
    loss and throughput and appends a "step<TAB>loss" line to
    ``output_data/output_<timestamp>.txt``. The graph is written to
    ``tb-logs/`` for TensorBoard.
    """
    output_path = 'output_data/output_' + str(time.time()) + '.txt'
    # Closed on every exit path, including training errors.
    with open(output_path, 'w') as output:
        with tf.Graph().as_default():
            global_step = tf.train.get_or_create_global_step()

            # Get images and labels for CIFAR-100.
            # Force input pipeline to CPU:0 to avoid operations sometimes
            # ending up on GPU and resulting in a slow down.
            with tf.device('/cpu:0'):
                images, labels = cifar100.distorted_inputs()

            # Build a Graph that computes the logits predictions from the
            # inference model.
            logitsA, logitsB = cifar100.inference(images)

            # Calculate loss.
            lossA = cifar100.loss(logitsA, labels)
            lossB = cifar100.loss(logitsB, labels)

            # Build a Graph that trains the model with one batch of examples
            # and updates the model parameters.
            train_opA = cifar100.train(lossA, global_step)
            train_opB = cifar100.train(lossB, global_step)

            class _LoggerHook(tf.train.SessionRunHook):
                """Logs loss A and runtime to stdout and to `output`."""

                def begin(self):
                    self._step = -1
                    self._start_time = time.time()

                def before_run(self, run_context):
                    self._step += 1
                    return tf.train.SessionRunArgs(lossA)  # Asks for loss value.

                def after_run(self, run_context, run_values):
                    if self._step % FLAGS.log_frequency == 0:
                        current_time = time.time()
                        duration = current_time - self._start_time
                        self._start_time = current_time

                        loss_value = run_values.results
                        examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                        sec_per_batch = float(duration / FLAGS.log_frequency)

                        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                      'sec/batch)')
                        print(format_str % (datetime.now(), self._step, loss_value,
                                            examples_per_sec, sec_per_batch))
                        print((str(self._step) + '\t' +
                               str(loss_value) + '\n'), file=output)

            with tf.train.MonitoredTrainingSession(
                    checkpoint_dir=FLAGS.train_dir,
                    hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                           tf.train.NanTensorHook(lossA),
                           tf.train.NanTensorHook(lossB),
                           _LoggerHook()],
                    config=tf.ConfigProto(
                        log_device_placement=FLAGS.log_device_placement)) as mon_sess:

                file_writer = tf.summary.FileWriter('tb-logs/', mon_sess.graph)
                try:
                    while not mon_sess.should_stop():
                        print("stepA")
                        mon_sess.run(train_opA)
                        # The A run can itself trigger a stop request (e.g.
                        # StopAtStepHook). MonitoredSession rejects run()
                        # calls after that, so re-check before running B.
                        if mon_sess.should_stop():
                            break
                        print("stepB")
                        mon_sess.run(train_opB)
                finally:
                    # Flush the graph event to disk and release the writer.
                    file_writer.close()
def train():
    """Synchronous distributed CIFAR-100 training with SyncReplicasOptimizer.

    Every process builds the cluster from ``FLAGS.ps_hosts`` /
    ``FLAGS.worker_hosts`` and starts a server. Parameter-server tasks block
    in ``server.join()``. Workers build the model under
    ``replica_device_setter``, so variables live on ``/job:ps/task:0`` and
    ops on this worker. Gradients from all workers are aggregated with
    ``SyncReplicasOptimizer`` (replicas_to_aggregate = number of workers).

    Worker task 0 is the chief. It also runs the sync queue runner and the
    init-tokens op, writes summaries every 100 steps and saves a checkpoint
    to ``FLAGS.train_dir`` every 1000 steps and after the last step.
    Checkpoint names use the local loop index ``step``, not the global step.

    Uses ``logger``, which is not defined in this function (presumably
    module-level).
    """
    print('FLAGS.data_dir: %s' % FLAGS.data_dir)
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")
    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)
    if FLAGS.job_name == 'ps':
        # Parameter servers only host variables; join() blocks here.
        server.join()
    is_chief = (FLAGS.task_index == 0)
    # Variables go to the ps task, other ops to this worker.
    with tf.device(
            tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                ps_device="/job:ps/task:0",
                cluster=cluster)):
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        # Get images and labels for CIFAR-100.
        images, labels = cifar100.distorted_inputs()
        num_workers = len(worker_hosts)
        # Wait for one gradient from every worker before each update.
        num_replicas_to_aggregate = num_workers
        logits = cifar100.inference(images)
        # Calculate loss.
        loss = cifar100.loss(logits, labels)
        # Retain the summaries from the chief.
        # Calculate the learning rate schedule.
        num_batches_per_epoch = (cifar100.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
                                 FLAGS.batch_size)
        decay_steps = int(num_batches_per_epoch *
                          cifar100.NUM_EPOCHS_PER_DECAY)
        # Decay the learning rate exponentially based on the number of steps.
        # staircase=True: the rate drops in discrete jumps every decay_steps.
        lr = tf.train.exponential_decay(cifar100.INITIAL_LEARNING_RATE,
                                        global_step,
                                        decay_steps,
                                        cifar100.LEARNING_RATE_DECAY_FACTOR,
                                        staircase=True)
        # `summaries` (and `summary_op` below) exist only on the chief; every
        # later use is guarded by `is_chief` as well.
        if is_chief:
            summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
            # Add a summary to track the learning rate.
            summaries.append(tf.summary.scalar('learning_rate', lr))

        # Create an optimizer that performs gradient descent.
        opt = tf.train.GradientDescentOptimizer(lr)
        # Wrap it so updates are applied only after gradients from
        # num_replicas_to_aggregate workers have been collected.
        opt = tf.train.SyncReplicasOptimizer(
            opt,
            replicas_to_aggregate=num_replicas_to_aggregate,
            total_num_replicas=num_workers,
            #use_locking=True)
            use_locking=False)
        # Calculate the gradients for the batch
        grads = opt.compute_gradients(loss)
        # Add histograms for gradients at the chief worker.
        if is_chief:
            for grad, var in grads:
                if grad is not None:
                    summaries.append(
                        tf.summary.histogram(var.op.name + '/gradients', grad))
        # apply gradients to variable
        train_op = opt.apply_gradients(grads, global_step=global_step)
        # Add histograms for trainable variables.
        if is_chief:
            for var in tf.trainable_variables():
                summaries.append(tf.summary.histogram(var.op.name, var))

        #variable_averages = tf.train.ExponentialMovingAverage(
        #      cifar100.MOVING_AVERAGE_DECAY, global_step)
        #variables_averages_op = variable_averages.apply(tf.trainable_variables())
        #train_op = tf.group(train_op, variables_averages_op)

        if is_chief:
            #Build the summary operation at the chief worker
            summary_op = tf.summary.merge(summaries)

    # Sync-replica plumbing: the chief queue runner and init tokens are only
    # started/run on the chief (see below).
    chief_queue_runner = opt.get_chief_queue_runner()
    init_token_op = opt.get_init_tokens_op()
    # Build an initialization operation to run below.
    init_op = tf.global_variables_initializer()
    # Create a saver.
    saver = tf.train.Saver(tf.global_variables())

    # The chief initializes the variables; other workers wait for it.
    sv = tf.train.Supervisor(is_chief=is_chief,
                             global_step=global_step,
                             init_op=init_op)
    # allow_soft_placement lets ops without a kernel for the requested
    # device fall back to one that has it.
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement)

    with sv.prepare_or_wait_for_session(server.target,
                                        config=sess_config) as sess:
        # Start the input-pipeline queue runners on every worker.
        queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
        sv.start_queue_runners(sess, queue_runners)

        # Only the chief starts the sync queue runner and runs the init
        # token op.
        if is_chief:
            sv.start_queue_runners(sess, [chief_queue_runner])
            sess.run(init_token_op)
        #open the summary writer
        # NOTE(review): created on every worker, although only the chief
        # adds summaries to it.
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        t1 = time.time()
        for step in xrange(FLAGS.max_steps):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
            if step % 10 == 0:
                # Throughput summed over all workers.
                num_examples_per_step = FLAGS.batch_size * num_workers
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = duration / num_workers
                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))
            if step % 100 == 0:
                if is_chief:
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)
            # Save the model checkpoint periodically.
            if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                if is_chief:
                    checkpoint_path = os.path.join(FLAGS.train_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

        t2 = time.time()
        print('spent %f seconds to train %d step' % (t2 - t1, FLAGS.max_steps))
        logger.info('spent %f seconds to train %d step' %
                    (t2 - t1, FLAGS.max_steps))
        # NOTE(review): loss_value is unbound if FLAGS.max_steps == 0.
        logger.info('last loss value: %.2f ' % loss_value)
def evaluate_images(images):  # run evaluation
    """Build the inference graph for ``images`` and load trained weights.

    Passes the logits tensor from ``cifar100.inference`` to
    ``load_trained_model`` (defined elsewhere) as its ``logits`` keyword
    argument. Returns None.
    """
    load_trained_model(logits=cifar100.inference(images))
# Example #13
# 0
def evaluate_images(images):  # run evaluation
    """Return the logits tensor built by ``cifar100.inference(images)``."""
    return cifar100.inference(images)