示例#1
0
 def testEndpointsReuse(self):
     """Check that building xception_65 with reuse yields the same end points.

     The network is built twice on the same inputs, first with
     ``reuse=False`` and then with ``reuse=True``; the keys of the two
     returned end-point collections must match.
     """
     inputs = create_test_input(2, 32, 32, 3)
     end_points_per_build = []
     for reuse in (False, True):
         with slim.arg_scope(xception.xception_arg_scope()):
             _, end_points = xception.xception_65(inputs,
                                                  num_classes=10,
                                                  reuse=reuse)
         end_points_per_build.append(end_points)
     first_build, second_build = end_points_per_build
     self.assertItemsEqual(first_build.keys(), second_build.keys())
示例#2
0
    def testUnknownBatchSize(self):
        """Build the small Xception with an unknown batch size and run it.

        The static logits shape must keep ``None`` as its batch dimension,
        and feeding a concrete batch must produce logits of shape
        ``(batch, 1, 1, num_classes)``.
        """
        batch = 2
        height, width = 65, 65
        global_pool = True
        num_classes = 10
        inputs = create_test_input(None, height, width, 3)

        with slim.arg_scope(xception.xception_arg_scope()):
            logits, _ = self._xception_small(inputs,
                                             num_classes,
                                             global_pool=global_pool,
                                             scope='xception')

        # The logits op must be named under 'xception/logits'.
        self.assertTrue(logits.op.name.startswith('xception/logits'))
        self.assertListEqual(logits.get_shape().as_list(),
                             [None, 1, 1, num_classes])

        images = create_test_input(batch, height, width, 3)

        with self.test_session() as sess:
            print('[tfTest] run testUnknownBatchSize()')
            sess.run(tf.global_variables_initializer())
            # `inputs` is used as a feed key here -- presumably a placeholder
            # when the batch size is None; confirm in create_test_input.
            output = sess.run(logits, {inputs: images.eval()})
            # assertEqual replaces the deprecated assertEquals alias, which
            # no longer exists in Python 3.12+.
            self.assertEqual(output.shape, (batch, 1, 1, num_classes))
示例#3
0
 def testAtrousFullyConvolutionalValues(self):
     """Verify dense feature extraction with atrous convolution.

     For each output stride, the densely extracted features are subsampled
     by ``nominal_stride // output_stride`` and compared against the
     features the same weights produce without an explicit output stride.
     """
     nominal_stride = 32
     for output_stride in [4, 8, 16, 32, None]:
         with slim.arg_scope(xception.xception_arg_scope()), \
                 tf.Graph().as_default(), \
                 self.test_session() as sess:
             tf.set_random_seed(0)
             inputs = create_test_input(2, 96, 97, 3)

             # Dense feature extraction at the requested output stride.
             dense, _ = self._xception_small(inputs,
                                             None,
                                             is_training=False,
                                             global_pool=False,
                                             output_stride=output_stride)

             # Subsample the dense features (factor 1 when no output
             # stride was requested).
             subsample_factor = (1 if output_stride is None else
                                 nominal_stride // output_stride)
             dense = resnet_utils.subsample(dense, subsample_factor)

             # Share weights between the two networks: the second build
             # below must reuse the variables created by the first.
             tf.get_variable_scope().reuse_variables()

             # Feature extraction without an explicit output stride.
             nominal, _ = self._xception_small(inputs,
                                               None,
                                               is_training=False,
                                               global_pool=False)

             sess.run(tf.global_variables_initializer())
             self.assertAllClose(dense.eval(),
                                 nominal.eval(),
                                 atol=1e-5,
                                 rtol=1e-5)
def load_model(sess, checkpoint_path='/model.ckpt'):
    """Build the Xception inference graph and restore its weights.

    Args:
        sess: TensorFlow session the checkpoint is restored into.
        checkpoint_path: Checkpoint to restore. Defaults to '/model.ckpt',
            the path that used to be hard-coded here.

    Returns:
        The tensor holding the argmax (over axis 1) of the model's
        'Predictions' end point.
    """
    print("Loading model...")

    placeholder = tf.placeholder(shape=[None, image_size, image_size, 3],
                                 dtype=tf.float32,
                                 name='Placeholder_only')

    # Create the inference model with is_training=False.
    with slim.arg_scope(xception_arg_scope()):
        _, end_points = xception(placeholder,
                                 num_classes=NUM_CLASSES,
                                 is_training=False)

    # Collect the variables to restore from the checkpoint file.
    variables_to_restore = slim.get_variables_to_restore()

    # Only the predictions are exposed; no loss is defined here.
    probabilities = end_points['Predictions']
    predictions = tf.argmax(probabilities, 1)

    # Hand the collected variables to the saver, as the other restore code
    # in this file does; previously the list was computed but never used.
    saver = tf.train.Saver(variables_to_restore)
    saver.restore(sess, checkpoint_path)
    return predictions
示例#5
0
    def testClassificationEndPoints(self):
        """Check the logits and the classification end points.

        Builds the small Xception with global pooling on a 2x224x224x3 input
        and checks the logits op name plus the static shapes of the logits,
        the 'predictions' end point and the 'global_pool' end point.
        """
        global_pool = True
        num_classes = 10
        inputs = create_test_input(2, 224, 224, 3)

        with slim.arg_scope(xception.xception_arg_scope()):
            logits, end_points = self._xception_small(inputs,
                                                      num_classes=num_classes,
                                                      global_pool=global_pool,
                                                      scope='xception')

        print('[tf.Test] run testClassificationEndPoints()')
        # Check the op name of the logits.
        self.assertTrue(logits.op.name.startswith('xception/logits'))

        # Check the static shape of the logits.
        self.assertListEqual(logits.get_shape().as_list(),
                             [2, 1, 1, num_classes])

        # assertIn reports the missing key and the container on failure,
        # unlike assertTrue(key in container).
        self.assertIn('predictions', end_points)
        self.assertListEqual(end_points['predictions'].get_shape().as_list(),
                             [2, 1, 1, num_classes])

        self.assertIn('global_pool', end_points)
        self.assertListEqual(end_points['global_pool'].get_shape().as_list(),
                             [2, 1, 1, 16])
示例#6
0
def run():
    """Freeze the newest Xception checkpoint in './log' into a GraphDef file.

    Rebuilds the inference graph, restores the latest checkpoint, converts
    the variables reachable from the softmax output node to constants and
    writes the serialized result to './frozen_model_xception.pb'.
    """
    input_size = 299
    class_count = 5
    checkpoint_dir = './log'
    frozen_graph_path = "./frozen_model_xception.pb"
    output_nodes = "Xception/Predictions/Softmax".split(",")

    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)

    with tf.Graph().as_default() as graph:
        images = tf.placeholder(shape=[None, input_size, input_size, 3],
                                dtype=tf.float32,
                                name='Placeholder_only')

        # Only the graph ops matter here; the returned tensors are unused.
        with slim.arg_scope(xception_arg_scope()):
            xception(images, num_classes=class_count, is_training=False)

        saver = tf.train.Saver(slim.get_variables_to_restore())

        # Snapshot the graph definition once the saver has been created.
        graph_def = graph.as_graph_def()

        with tf.Session() as sess:
            saver.restore(sess, latest_checkpoint)

            print("Exporting graph...")
            frozen_graph_def = graph_util.convert_variables_to_constants(
                sess, graph_def, output_nodes)

            with tf.gfile.GFile(frozen_graph_path, "wb") as f:
                f.write(frozen_graph_def.SerializeToString())
示例#7
0
def main(_):
  """Run one task (parameter server or worker) of a distributed training job.

  The role is taken from ``settings.FLAGS.job_name``: a "ps" task blocks in
  ``server.join()``, a "worker" task builds the model chosen interactively
  and trains it until ``settings.FLAGS.max_steps`` global steps are reached.

  Args:
    _: Unused positional argument.
  """
  flags = settings.FLAGS
  if flags.job_name == "worker" and flags.task_index == 0:
    model_inputs.maybe_download_and_extract()
  ps_hosts = flags.ps_hosts.split(",")
  worker_hosts = flags.worker_hosts.split(",")

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

  # Create and start a server for the local task.
  server = tf.train.Server(cluster,
                           job_name=flags.job_name,
                           task_index=flags.task_index)

  if flags.job_name == "ps":
    server.join()
  elif flags.job_name == "worker":
    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % flags.task_index,
        cluster=cluster)):
      isXception = trainer_functions.query_yes_no("Would you like to use the Xception model \n(if no, the model will default to that of the TensorFlow turorial)?")
      # Both model variants read their inputs through the same call.
      images, labels = trainer_functions.distorted_inputs(isXception)
      # Build model
      if isXception:
        with slim.arg_scope(xception.xception_arg_scope()):
          logits, _ = xception.xception(images, num_classes=10, is_training=True)
      else:
        logits = trainer_functions.tutorial_model(images)
      # Calculate loss.
      loss = trainer_functions.loss(logits, labels)
      global_step = tf.contrib.framework.get_or_create_global_step()

      train_op = tf.train.AdagradOptimizer(0.01).minimize(
          loss, global_step=global_step)

    # The StopAtStepHook handles stopping after running given steps.
    hooks = [tf.train.StopAtStepHook(last_step=flags.max_steps)]

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(flags.task_index == 0),
                                           checkpoint_dir="./train_logs",
                                           hooks=hooks) as mon_sess:
      prev_time = time.time()
      while not mon_sess.should_stop():
        # Run a training step asynchronously.
        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
        # perform *synchronous* training.
        # mon_sess.run handles AbortedError in case of preempted PS.
        #
        # The step counter and the loss are fetched in the same call as the
        # train op. Issuing further mon_sess.run() calls after the train op
        # would run the hooks again, evaluate the loss on another batch and,
        # once the stop hook has fired on the last step, raise because the
        # session has already been asked to stop.
        _, step, loss_value = mon_sess.run([train_op, global_step, loss])
        if step % 20 == 0:
          now = time.time()
          duration = now - prev_time
          prev_time = now
          # NOTE(review): the throughput assumes `log_frequency` steps per
          # report while reports are triggered every 20 global steps --
          # confirm the flag is meant to be 20.
          examples_per_sec = flags.log_frequency * flags.batch_size / duration
          print("examples/sec: %d, loss: %f" % (examples_per_sec, loss_value))
示例#8
0
 def testFullyConvolutionalUnknownHeightWidth(self):
     """Run the small Xception with unknown spatial input dimensions.

     The static output shape must keep ``None`` for height and width, and
     feeding 65x65 images must produce an output of shape
     ``(batch, 3, 3, 16)``.
     """
     batch = 2
     height, width = 65, 65
     global_pool = False
     inputs = create_test_input(batch, None, None, 3)
     with slim.arg_scope(xception.xception_arg_scope()):
         output, _ = self._xception_small(inputs,
                                          None,
                                          global_pool=global_pool)
     self.assertListEqual(output.get_shape().as_list(),
                          [batch, None, None, 16])
     images = create_test_input(batch, height, width, 3)
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         # Keep the evaluated array under its own name instead of rebinding
         # the `output` tensor.
         output_value = sess.run(output, {inputs: images.eval()})
         # assertEqual replaces the deprecated assertEquals alias, which no
         # longer exists in Python 3.12+.
         self.assertEqual(output_value.shape, (batch, 3, 3, 16))
示例#9
0
 def testFullyConvolutionalEndpointShapes(self):
     """Check end-point shapes of the small Xception without global pooling.

     The network is built on a 2x321x321x3 input and the static shape of
     each listed end point is compared against its expected value.
     """
     expected_shapes = {
         'xception/entry_flow/conv1_1': [2, 161, 161, 32],
         'xception/entry_flow/block1': [2, 81, 81, 1],
         'xception/entry_flow/block2': [2, 41, 41, 2],
         'xception/entry_flow/block4': [2, 21, 21, 4],
         'xception/middle_flow/block1': [2, 21, 21, 4],
         'xception/exit_flow/block1': [2, 11, 11, 8],
         'xception/exit_flow/block2': [2, 11, 11, 16]
     }
     inputs = create_test_input(2, 321, 321, 3)
     with slim.arg_scope(xception.xception_arg_scope()):
         _, end_points = self._xception_small(inputs,
                                              10,
                                              global_pool=False,
                                              scope='xception')
     for name, expected in six.iteritems(expected_shapes):
         actual = end_points[name].get_shape().as_list()
         self.assertListEqual(actual, expected)
示例#10
0
 def testClassificationShapes(self):
     """Check end-point shapes of the small Xception with global pooling.

     The network is built on a 2x224x224x3 input and the static shape of
     each listed end point is compared against its expected value.
     """
     expected_shapes = {
         'xception/entry_flow/conv1_1': [2, 112, 112, 32],
         'xception/entry_flow/block1': [2, 56, 56, 1],
         'xception/entry_flow/block2': [2, 28, 28, 2],
         'xception/entry_flow/block4': [2, 14, 14, 4],
         'xception/middle_flow/block1': [2, 14, 14, 4],
         'xception/exit_flow/block1': [2, 7, 7, 8],
         'xception/exit_flow/block2': [2, 7, 7, 16]
     }
     inputs = create_test_input(2, 224, 224, 3)
     with slim.arg_scope(xception.xception_arg_scope()):
         _, end_points = self._xception_small(inputs,
                                              10,
                                              global_pool=True,
                                              scope='xception')
     for name, expected in six.iteritems(expected_shapes):
         actual = end_points[name].get_shape().as_list()
         self.assertListEqual(actual, expected)
示例#11
0
def run():
    """Train the Xception classifier on the 'train' split.

    Builds the input pipeline, the model, the loss, an exponentially
    decaying learning rate and an Adam optimizer, then runs
    ``num_steps_per_epoch * num_epochs`` training steps inside a
    ``tf.train.Supervisor`` managed session. Summaries are written every
    10 steps and a checkpoint is saved once training has finished.

    Reads its configuration (``log_dir``, ``dataset_dir``, ``file_pattern``,
    ``batch_size``, ``num_epochs``, learning-rate settings) from names
    defined outside this function.
    """
    # Create the log directory here. Must be done here, otherwise merely
    # importing this module would create it.
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    # ======================= TRAINING PROCESS =========================
    # Now we start to construct the graph and build our model.
    with tf.Graph().as_default() as graph:
        # Set the logging verbosity to INFO level.
        tf.logging.set_verbosity(tf.logging.INFO)

        # First create the dataset and load one batch.
        dataset = get_split('train', dataset_dir, file_pattern=file_pattern)
        images, _, labels = load_batch(dataset, batch_size=batch_size)

        # Know the number of steps to take before decaying the learning rate
        # and the number of batches per epoch.
        num_batches_per_epoch = dataset.num_samples // batch_size
        num_steps_per_epoch = num_batches_per_epoch  # one step is one batch processed
        decay_steps = int(num_epochs_before_decay * num_steps_per_epoch)

        # Create the model inference.
        with slim.arg_scope(xception_arg_scope()):
            logits, end_points = xception(images, num_classes=dataset.num_classes, is_training=True)

        # Perform one-hot encoding of the labels.
        one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)

        # Compute the classification loss. Its return value is not needed:
        # the total loss fetched below is what gets optimized.
        tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels, logits=logits)
        total_loss = tf.losses.get_total_loss()  # obtain the regularization losses as well

        # Create the global step for monitoring the learning rate and training.
        global_step = get_or_create_global_step()

        # Define the exponentially decaying learning rate.
        lr = tf.train.exponential_decay(
            learning_rate=initial_learning_rate,
            global_step=global_step,
            decay_steps=decay_steps,
            decay_rate=learning_rate_decay_factor,
            staircase=True)

        # Define the optimizer that takes on the learning rate.
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)

        # Create the train_op.
        train_op = slim.learning.create_train_op(total_loss, optimizer)

        # State the metrics to track. The predictions are class indices,
        # not one-hot encoded.
        predictions = tf.argmax(end_points['Predictions'], 1)
        probabilities = end_points['Predictions']
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, labels)
        metrics_op = tf.group(accuracy_update, probabilities)

        # Create all the summaries to monitor and group them into one op.
        tf.summary.scalar('losses/Total_Loss', total_loss)
        tf.summary.scalar('accuracy', accuracy)
        tf.summary.scalar('learning_rate', lr)
        my_summary_op = tf.summary.merge_all()

        def train_step(sess, train_op, global_step):
            """Run one training step and log its loss and duration.

            Runs `train_op`, `global_step` and the enclosing `metrics_op`
            in a single session call.

            Returns:
                A (total_loss, global_step_count) tuple of fetched values.
            """
            # Time the session run.
            start_time = time.time()
            total_loss, global_step_count, _ = sess.run([train_op, global_step, metrics_op])
            time_elapsed = time.time() - start_time

            logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed)

            return total_loss, global_step_count

        # Define the supervisor for running a managed session.
        # Do not run the summary_op automatically or else it will consume too much memory.
        sv = tf.train.Supervisor(logdir=log_dir, summary_op=None)

        # Run the managed session.
        with sv.managed_session() as sess:
            # Last fetched loss value; kept under its own name so it does not
            # shadow any graph tensor.
            loss_value = None
            for step in range(num_steps_per_epoch * num_epochs):
                # At the start of every epoch, show the vital information.
                if step % num_batches_per_epoch == 0:
                    # Floor division keeps the epoch number an integer
                    # ("Epoch 1/10" rather than "Epoch 1.0/10").
                    logging.info('Epoch %s/%s', step // num_batches_per_epoch + 1, num_epochs)
                    learning_rate_value, accuracy_value = sess.run([lr, accuracy])
                    logging.info('Current Learning Rate: %s', learning_rate_value)
                    logging.info('Current Streaming Accuracy: %s', accuracy_value)

                    # Optionally print logits and predictions as a sanity check.
                    logits_value, probabilities_value, predictions_value, labels_value = sess.run(
                        [logits, probabilities, predictions, labels])
                    print('logits: \n', logits_value[:5])
                    print('Probabilities: \n', probabilities_value[:5])
                    print('predictions: \n', predictions_value[:5])
                    print('Labels:\n:', labels_value[:5])

                # Every step trains; every 10th step additionally writes summaries.
                loss_value, _ = train_step(sess, train_op, sv.global_step)
                if step % 10 == 0:
                    summaries = sess.run(my_summary_op)
                    sv.summary_computed(sess, summaries)

            # Log the final training loss and accuracy.
            logging.info('Final Loss: %s', loss_value)
            logging.info('Final Accuracy: %s', sess.run(accuracy))

            # Once all the training has been done, save the checkpoint model.
            # The message below announced a save that was never performed.
            logging.info('Finished training! Saving model to disk now.')
            sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
示例#12
0
    def model(inputs, region, is_training):
        """Constructs the Xception / ASPP / ConvLSTM segmentation model.

        Reads `data_format`, `batch_norm_decay`, `pre_trained_model` and
        `num_classes` from the enclosing scope.

        Args:
            inputs: Input image batch, given in channels_last layout.
            region: Second input passed to the Xception encoder together
                with `inputs`.
            is_training: Python bool selecting training behaviour (batch
                norm mode, checkpoint initialization, dropout).

        Returns:
            The output of the final 1x1 convolution with `num_classes`
            channels (no activation, no normalizer).
        """
        if data_format == 'channels_first':
            # Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
            # This provides a large performance boost on GPU. See
            # https://www.tensorflow.org/performance/performance_guide#data_formats
            inputs = tf.transpose(inputs, [0, 3, 1, 2])
            region = tf.transpose(region, [0, 3, 1, 2])

        # Encoder.
        with tf.contrib.slim.arg_scope(
                xception.xception_arg_scope(
                    batch_norm_decay=batch_norm_decay)):
            _, end_points, low_level_features = xception.xception(
                inputs, region, num_classes=None, is_training=is_training)

        if is_training and pre_trained_model is not None:
            # Initialize every variable created so far from the checkpoint,
            # mapping each checkpoint name to the variable of the same name.
            variables_to_restore = tf.contrib.slim.get_variables_to_restore(
                exclude=None)
            tf.train.init_from_checkpoint(
                pre_trained_model,
                {v.name.split(':')[0]: v
                 for v in variables_to_restore})

        # NOTE(review): this slice is taken after the optional transpose
        # above, so with data_format == 'channels_first' it selects
        # (channels, height) rather than (height, width) -- confirm whether
        # channels_first is actually supported by the decoder below.
        inputs_size = tf.shape(inputs)[1:3]
        net = end_points['Logits']
        encoder_output = atrous_spatial_pyramid_pooling(
            net, batch_norm_decay, is_training)

        with tf.variable_scope("lstm"):
            with tf.contrib.slim.arg_scope(
                    xception.xception_arg_scope(
                        batch_norm_decay=batch_norm_decay)):
                with arg_scope([layers.batch_norm], is_training=is_training):
                    # Static size of dimension 2 of the encoder output; it is
                    # used for both spatial dimensions below, which assumes a
                    # square feature map -- TODO confirm.
                    k_size = encoder_output.get_shape().as_list()[2]
                    net = layers_lib.conv2d(encoder_output,
                                            64, [1, 1],
                                            stride=1,
                                            scope="conv1_1x1")
                    # Group every 4 consecutive batch entries into one
                    # sequence of length 4 for the recurrent layer.
                    rnn_input = tf.reshape(net, [-1, 4, k_size, k_size, 64])
                    cell_1 = tf.contrib.rnn.ConvLSTMCell(
                        conv_ndims=2,
                        input_shape=[k_size, k_size, 64],
                        output_channels=64,
                        kernel_shape=[3, 3])
                    multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell([cell_1])
                    # Input dropout is a training-time regularizer; it used
                    # to be applied unconditionally, which also randomly
                    # dropped inputs during evaluation and prediction.
                    dropout_rnn = tf.contrib.rnn.DropoutWrapper(
                        multi_rnn_cell,
                        input_keep_prob=0.4 if is_training else 1.0)
                    rnn_outputs, _ = tf.nn.dynamic_rnn(cell=dropout_rnn,
                                                       inputs=rnn_input,
                                                       time_major=False,
                                                       dtype=tf.float32)

        with tf.variable_scope("decoder"):
            with tf.contrib.slim.arg_scope(
                    xception.xception_arg_scope(
                        batch_norm_decay=batch_norm_decay)):
                with arg_scope([layers.batch_norm], is_training=is_training):
                    with tf.variable_scope("low_level_features"):
                        low_level_features = layers_lib.conv2d(
                            low_level_features,
                            24, [1, 1],
                            stride=1,
                            scope='conv_1x1')
                        low_level_features_size = tf.shape(
                            low_level_features)[1:3]

                    with tf.variable_scope("upsampling_logits"):
                        # Undo the sequence grouping: back to one feature map
                        # per batch entry.
                        encoder_output_re = tf.reshape(
                            rnn_outputs, [-1, k_size, k_size, 64])
                        net = layers_lib.conv2d(encoder_output_re,
                                                64, [3, 3],
                                                stride=1,
                                                scope='conv_3x3_1')
                        net = layers_lib.conv2d(net,
                                                64, [3, 3],
                                                stride=1,
                                                scope='conv_3x3_2')
                        # First upsampling: to the size of the low-level
                        # features, which are then concatenated on axis 3.
                        net = tf.image.resize_bilinear(net,
                                                       low_level_features_size,
                                                       name='upsample_1')
                        net = tf.concat([net, low_level_features],
                                        axis=3,
                                        name='concat')
                        net = layers_lib.conv2d(net,
                                                44, [3, 3],
                                                stride=1,
                                                scope='conv_3x3_3')
                        net = layers_lib.conv2d(net,
                                                44, [3, 3],
                                                stride=1,
                                                scope='conv_3x3_4')
                        # Second upsampling: to the size read from `inputs`.
                        net = tf.image.resize_bilinear(net,
                                                       inputs_size,
                                                       name='upsample_2')
                        # A single-channel 1x1 projection of the raw inputs is
                        # concatenated as a second skip connection.
                        low_level_features_two = layers_lib.conv2d(
                            inputs,
                            1, [1, 1],
                            stride=1,
                            scope='low_level_feature_conv_1x1')
                        net = tf.concat([net, low_level_features_two],
                                        axis=3,
                                        name='concat_2')
                        logits = layers_lib.conv2d(net,
                                                   num_classes, [1, 1],
                                                   activation_fn=None,
                                                   normalizer_fn=None,
                                                   scope='outputs')

        return logits
示例#13
0
 def testEndpointNames(self):
     """Check that the small Xception exposes exactly the expected end points.

     The expected names are grouped by entry flow, middle flow, exit flow
     and classification head; the comparison ignores ordering.
     """
     inputs = create_test_input(2, 224, 224, 3)
     with slim.arg_scope(xception.xception_arg_scope()):
         _, end_points = self._xception_small(inputs,
                                              num_classes=10,
                                              global_pool=True,
                                              scope='xception')

     entry_flow = [
         'xception/entry_flow/conv1_1',
         'xception/entry_flow/conv1_2',
         'xception/entry_flow/block1/unit_1/xception_module/separable_conv1',
         'xception/entry_flow/block1/unit_1/xception_module/separable_conv2',
         'xception/entry_flow/block1/unit_1/xception_module/separable_conv3',
         'xception/entry_flow/block1/unit_1/xception_module/shortcut',
         'xception/entry_flow/block1/unit_1/xception_module',
         'xception/entry_flow/block1',
         'xception/entry_flow/block2/unit_1/xception_module/separable_conv1',
         'xception/entry_flow/block2/unit_1/xception_module/separable_conv2',
         'xception/entry_flow/block2/unit_1/xception_module/separable_conv3',
         'xception/entry_flow/block2/unit_1/xception_module/shortcut',
         'xception/entry_flow/block2/unit_1/xception_module',
         'xception/entry_flow/block2',
         'xception/entry_flow/block3/unit_1/xception_module/separable_conv1',
         'xception/entry_flow/block3/unit_1/xception_module/separable_conv2',
         'xception/entry_flow/block3/unit_1/xception_module/separable_conv3',
         'xception/entry_flow/block3/unit_1/xception_module/shortcut',
         'xception/entry_flow/block3/unit_1/xception_module',
         'xception/entry_flow/block3',
         'xception/entry_flow/block4/unit_1/xception_module/separable_conv1',
         'xception/entry_flow/block4/unit_1/xception_module/separable_conv2',
         'xception/entry_flow/block4/unit_1/xception_module/separable_conv3',
         'xception/entry_flow/block4/unit_1/xception_module/shortcut',
         'xception/entry_flow/block4/unit_1/xception_module',
         'xception/entry_flow/block4',
     ]
     middle_flow = [
         'xception/middle_flow/block1/unit_1/xception_module/separable_conv1',
         'xception/middle_flow/block1/unit_1/xception_module/separable_conv2',
         'xception/middle_flow/block1/unit_1/xception_module/separable_conv3',
         'xception/middle_flow/block1/unit_1/xception_module',
         'xception/middle_flow/block1/unit_2/xception_module/separable_conv1',
         'xception/middle_flow/block1/unit_2/xception_module/separable_conv2',
         'xception/middle_flow/block1/unit_2/xception_module/separable_conv3',
         'xception/middle_flow/block1/unit_2/xception_module',
         'xception/middle_flow/block1',
     ]
     exit_flow = [
         'xception/exit_flow/block1/unit_1/xception_module/separable_conv1',
         'xception/exit_flow/block1/unit_1/xception_module/separable_conv2',
         'xception/exit_flow/block1/unit_1/xception_module/separable_conv3',
         'xception/exit_flow/block1/unit_1/xception_module/shortcut',
         'xception/exit_flow/block1/unit_1/xception_module',
         'xception/exit_flow/block1',
         'xception/exit_flow/block2/unit_1/xception_module/separable_conv1',
         'xception/exit_flow/block2/unit_1/xception_module/separable_conv2',
         'xception/exit_flow/block2/unit_1/xception_module/separable_conv3',
         'xception/exit_flow/block2/unit_1/xception_module',
         'xception/exit_flow/block2',
     ]
     head = [
         'global_pool',
         'xception/logits',
         'predictions',
     ]

     expected = entry_flow + middle_flow + exit_flow + head
     self.assertItemsEqual(end_points.keys(), expected)
def run():
    """Evaluate a restored Xception checkpoint on the validation split.

    Builds a fresh graph, restores ``checkpoint_file`` into it, runs
    ``num_epochs`` passes of streaming-accuracy evaluation, writes summaries
    to ``log_eval`` and finally plots up to ten images of one more batch
    together with their predicted and ground-truth labels.

    Relies on module-level configuration (``log_eval``, ``dataset_dir``,
    ``batch_size``, ``num_epochs``, ``checkpoint_file``) and on helpers
    defined elsewhere (``get_split``, ``load_batch``, ``xception``,
    ``xception_arg_scope``, ``get_or_create_global_step``).
    """
    # Create the log directory (including any missing parents).
    if not os.path.exists(log_eval):
        os.makedirs(log_eval)

    # Construct the evaluation graph from scratch.
    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        # Load one batch of validation tensors; is_training=False selects the
        # evaluation preprocessing.
        dataset = get_split('validation', dataset_dir)
        images, raw_images, labels = load_batch(dataset,
                                                batch_size=batch_size,
                                                is_training=False)

        # Floor division keeps this an int so that the per-epoch modulo check
        # and the logged epoch number behave the same on Python 2 and 3.
        num_batches_per_epoch = dataset.num_samples // batch_size

        # Build the inference model with is_training=False.
        with slim.arg_scope(xception_arg_scope()):
            logits, end_points = xception(images,
                                          num_classes=dataset.num_classes,
                                          is_training=False)

        # Saver over every restorable variable, used to load the checkpoint.
        variables_to_restore = slim.get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        # Metrics only; no loss is needed for evaluation.
        probabilities = end_points['Predictions']
        predictions = tf.argmax(probabilities, 1)

        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(
            predictions, labels)
        metrics_op = tf.group(accuracy_update)

        # There is no apply_gradients call during evaluation, so the global
        # step is incremented manually for monitoring.
        global_step = get_or_create_global_step()
        global_step_op = tf.assign(global_step, global_step + 1)

        def eval_step(sess, metrics_op, global_step):
            '''
            Run one evaluation step: update the metrics, bump the global step
            and log the streaming accuracy plus the time taken.

            `global_step` is accepted for call-site symmetry but unused; the
            increment op is taken from the enclosing scope.
            Returns the current streaming accuracy value.
            '''
            start_time = time.time()
            _, global_step_count, accuracy_value = sess.run(
                [metrics_op, global_step_op, accuracy])
            time_elapsed = time.time() - start_time

            logging.info(
                'Global Step %s: Streaming Accuracy: %.4f (%.2f sec/step)',
                global_step_count, accuracy_value, time_elapsed)

            return accuracy_value

        # Scalar and text summaries to monitor.
        tf.summary.scalar('Validation_Accuracy', accuracy)
        with tf.name_scope('confusion_matrix'):
            # Use the same class count the model was built with so the matrix
            # shape always matches the logits.
            confusion_matrix = tf.as_string(
                tf.confusion_matrix(labels=labels,
                                    predictions=predictions,
                                    num_classes=dataset.num_classes,
                                    name='Confusion'))

            tf.summary.text("confusion_matrix", confusion_matrix)

        my_summary_op = tf.summary.merge_all()

        # summary_op=None: summaries are computed manually in the loop below.
        sv = tf.train.Supervisor(logdir=log_eval,
                                 summary_op=None,
                                 init_fn=restore_fn)

        with sv.managed_session() as sess:
            for step in range(int(num_batches_per_epoch * num_epochs)):
                # Report progress at the start of every epoch.
                if step % num_batches_per_epoch == 0:
                    logging.info('Epoch: %s/%s',
                                 step // num_batches_per_epoch + 1, num_epochs)
                    logging.info('Current Streaming Accuracy: %.4f',
                                 sess.run(accuracy))

                eval_step(sess,
                          metrics_op=metrics_op,
                          global_step=sv.global_step)

                # Additionally compute summaries every 10 steps.
                if step % 10 == 0:
                    summaries = sess.run(my_summary_op)
                    sv.summary_computed(sess, summaries)

            # At the end of all the evaluation, show the final accuracy.
            logging.info('Final Streaming Accuracy: %.4f', sess.run(accuracy))

            # Fetch one more batch to visualise what the model predicted.
            # Separate names keep the graph tensors from being shadowed.
            image_vals, label_vals, prediction_vals, probability_vals = sess.run(
                [raw_images, labels, predictions, probabilities])

            # Show at most ten samples; a batch may hold fewer than ten.
            for i in range(min(10, len(image_vals))):
                image = image_vals[i]
                label = label_vals[i]
                prediction = prediction_vals[i]
                probability = probability_vals[i]
                prediction_name = dataset.labels_to_name[prediction]
                label_name = dataset.labels_to_name[label]
                text = 'Prediction: %s \n Ground Truth: %s \n Probability: %s' % (
                    prediction_name, label_name, probability[prediction])
                img_plot = plt.imshow(image)

                # Set up the plot and hide axes.
                plt.title(text)
                img_plot.axes.get_yaxis().set_ticks([])
                img_plot.axes.get_xaxis().set_ticks([])
                plt.show()

            logging.info(
                'Model evaluation has completed! Visit TensorBoard for more information regarding your evaluation.'
            )
            sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
    
# Op that re-initializes `iterator` to read from `final_test_dataset`.
finaltest_init_op = iterator.make_initializer(final_test_dataset)

# Next element of the iterator; presumably an image batch and the matching
# names/labels -- confirm against the dataset definition.
X,Y_name = iterator.get_next()
### data API end

######## Is_training
Is_training = tf.placeholder(tf.bool)  # boolean flag fed to the network's is_training argument
### end

### dropout
# Scalar placeholder passed to the network as keep_prob.
KEEP_PROB = tf.placeholder(tf.float32)
### dropout end

### Network definition starts here
with slim.arg_scope(xception_arg_scope()):
    Y_prediction,end_points = xception(X,
                     num_classes=16,
                     is_training=Is_training,
                     scope='xception',
                     keep_prob=KEEP_PROB)
### Network definition ends here
# Softmax over the network output.
Y_softmax = tf.nn.softmax(Y_prediction)


# NOTE(review): both helpers run after the graph above has been built;
# confirm this ordering is intended.
initialization()  # per the original note: initializes the training and test sets, gets their sizes, and looks up the class for each id
preprocess.main()  # create the TFRecord files

variables_to_restore = slim.get_variables_to_restore()
### saver
saver = tf.train.Saver(variables_to_restore,max_to_keep = 1)  # saver over all restorable variables; max_to_keep=1 keeps only the most recent checkpoint