def testEndpointsReuse(self):
    """Building xception_65 a second time with reuse=True yields the same end points.

    The network is built twice on the same input, first creating variables
    (reuse=False) and then reusing them (reuse=True); the two end-point
    dictionaries must expose the same set of keys.
    """
    inputs = create_test_input(2, 32, 32, 3)
    collected = []
    for reuse_flag in (False, True):
        with slim.arg_scope(xception.xception_arg_scope()):
            _, points = xception.xception_65(inputs, num_classes=10, reuse=reuse_flag)
        collected.append(points)
    first_build, second_build = collected
    self.assertItemsEqual(first_build.keys(), second_build.keys())
def testUnknownBatchSize(self):
    """The small Xception accepts an input whose batch dimension is unknown.

    Builds the graph on an input with batch=None, checks the static logits
    shape keeps the unknown batch dimension, then runs it on a concrete batch
    of 2 and checks the resulting numpy shape.
    """
    batch = 2
    height, width = 65, 65
    global_pool = True
    num_classes = 10
    inputs = create_test_input(None, height, width, 3)
    with slim.arg_scope(xception.xception_arg_scope()):
        logits, _ = self._xception_small(inputs, num_classes,
                                         global_pool=global_pool,
                                         scope='xception')
    # Does the model name its logits op under 'xception/logits'?
    self.assertTrue(logits.op.name.startswith('xception/logits'))
    self.assertListEqual(logits.get_shape().as_list(),
                         [None, 1, 1, num_classes])
    images = create_test_input(batch, height, width, 3)
    with self.test_session() as sess:
        print('[tfTest] run testUnknownBatchSize()')
        sess.run(tf.global_variables_initializer())
        output = sess.run(logits, {inputs: images.eval()})
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual(output.shape, (batch, 1, 1, num_classes))
def testAtrousFullyConvolutionalValues(self):
    """Verify dense feature extraction with atrous convolution.

    For each requested output stride, the densely-computed features
    (subsampled back to the nominal stride of 32) must match the features
    computed directly at the nominal stride with the same weights.
    """
    nominal_stride = 32
    for output_stride in [4, 8, 16, 32, None]:
        # How much denser than nominal the atrous output is.
        if output_stride is None:
            factor = 1
        else:
            factor = nominal_stride // output_stride
        with slim.arg_scope(xception.xception_arg_scope()), \
                tf.Graph().as_default(), \
                self.test_session() as sess:
            tf.set_random_seed(0)
            inputs = create_test_input(2, 96, 97, 3)
            # Dense feature extraction followed by subsampling.
            dense_features, _ = self._xception_small(
                inputs, None, is_training=False, global_pool=False,
                output_stride=output_stride)
            dense_features = resnet_utils.subsample(dense_features, factor)
            # Share weights between the two networks.
            tf.get_variable_scope().reuse_variables()
            # Feature extraction at the nominal network rate.
            expected, _ = self._xception_small(
                inputs, None, is_training=False, global_pool=False)
            sess.run(tf.global_variables_initializer())
            self.assertAllClose(dense_features.eval(), expected.eval(),
                                atol=1e-5, rtol=1e-5)
def load_model(sess, checkpoint_path='/model.ckpt'):
    """Build the Xception inference graph and restore its weights into ``sess``.

    Creates a float32 placeholder named 'Placeholder_only' with shape
    [None, image_size, image_size, 3] and builds the Xception model on it
    with is_training=False. ``image_size`` and ``NUM_CLASSES`` are read from
    the enclosing module.

    Args:
        sess: TensorFlow session into which the variables are restored.
        checkpoint_path: Checkpoint prefix to restore from. Defaults to the
            previously hard-coded '/model.ckpt'.

    Returns:
        A tensor of predicted class indices (argmax over
        ``end_points['Predictions']`` along axis 1). Callers feed images
        through the 'Placeholder_only' placeholder, which this function does
        not return.
    """
    print("Loading model...")
    placeholder = tf.placeholder(shape=[None, image_size, image_size, 3],
                                 dtype=tf.float32, name='Placeholder_only')
    # Build the inference model (is_training=False).
    with slim.arg_scope(xception_arg_scope()):
        _, end_points = xception(placeholder, num_classes=NUM_CLASSES,
                                 is_training=False)

    probabilities = end_points['Predictions']
    predictions = tf.argmax(probabilities, 1)

    # Restore exactly the model variables; previously this list was computed
    # but not passed to the Saver.
    variables_to_restore = slim.get_variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)
    saver.restore(sess, checkpoint_path)
    return predictions
def testClassificationEndPoints(self):
    """Check the logits op name/shape and the 'predictions'/'global_pool' end points."""
    global_pool = True
    num_classes = 10
    inputs = create_test_input(2, 224, 224, 3)
    with slim.arg_scope(xception.xception_arg_scope()):
        logits, end_points = self._xception_small(inputs,
                                                  num_classes=num_classes,
                                                  global_pool=global_pool,
                                                  scope='xception')
    print('[tf.Test] run testClassificationEndPoints()')
    # Check the logits op name.
    self.assertTrue(logits.op.name.startswith('xception/logits'))
    # Check the logits shape.
    self.assertListEqual(logits.get_shape().as_list(),
                         [2, 1, 1, num_classes])
    # assertIn reports the missing key and the container on failure,
    # unlike assertTrue(key in d) which only says "False is not true".
    self.assertIn('predictions', end_points)
    self.assertListEqual(end_points['predictions'].get_shape().as_list(),
                         [2, 1, 1, num_classes])
    self.assertIn('global_pool', end_points)
    self.assertListEqual(end_points['global_pool'].get_shape().as_list(),
                         [2, 1, 1, 16])
def run(image_size=299, num_classes=5, logdir='./log',
        output_graph_name="./frozen_model_xception.pb",
        output_node_names="Xception/Predictions/Softmax"):
    """Freeze the latest Xception checkpoint in ``logdir`` into a GraphDef file.

    Builds the inference graph (is_training=False) on a float32 placeholder
    named 'Placeholder_only', restores the latest checkpoint, converts
    variables to constants, and writes the serialized GraphDef to
    ``output_graph_name``. All defaults equal the previously hard-coded values.

    Args:
        image_size: Height and width of the square input placeholder.
        num_classes: Number of output classes of the model.
        logdir: Directory searched with ``tf.train.latest_checkpoint``.
        output_graph_name: Path of the frozen GraphDef to write.
        output_node_names: Comma-separated names of the nodes to keep.

    Raises:
        ValueError: if ``logdir`` contains no checkpoint.
    """
    checkpoint_file = tf.train.latest_checkpoint(logdir)
    if checkpoint_file is None:
        # Fail early with a clear message instead of inside saver.restore().
        raise ValueError('No checkpoint found in %r' % (logdir,))

    with tf.Graph().as_default() as graph:
        images = tf.placeholder(shape=[None, image_size, image_size, 3],
                                dtype=tf.float32, name='Placeholder_only')
        with slim.arg_scope(xception_arg_scope()):
            xception(images, num_classes=num_classes, is_training=False)

        variables_to_restore = slim.get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # Snapshot the graph definition that will be frozen.
        input_graph_def = graph.as_graph_def()

        with tf.Session() as sess:
            saver.restore(sess, checkpoint_file)

            # Export the graph with variables baked in as constants.
            print("Exporting graph...")
            output_graph_def = graph_util.convert_variables_to_constants(
                sess, input_graph_def, output_node_names.split(","))

            with tf.gfile.GFile(output_graph_name, "wb") as f:
                f.write(output_graph_def.SerializeToString())
def main(_):
    """Entry point for parameter-server / worker distributed training.

    The task's role comes from ``settings.FLAGS``. A 'ps' task serves
    variables via ``server.join()``. A 'worker' task asks interactively
    whether to use the Xception model or the tutorial model, builds it, and
    trains with Adagrad in a MonitoredTrainingSession. Training stops once
    the global step reaches ``settings.FLAGS.max_steps``.

    Args:
        _: Unused positional argument (presumably argv from ``tf.app.run``).
    """
    if settings.FLAGS.job_name == "worker" and settings.FLAGS.task_index == 0:
        model_inputs.maybe_download_and_extract()

    ps_hosts = settings.FLAGS.ps_hosts.split(",")
    worker_hosts = settings.FLAGS.worker_hosts.split(",")

    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Create and start a server for the local task.
    server = tf.train.Server(cluster,
                             job_name=settings.FLAGS.job_name,
                             task_index=settings.FLAGS.task_index)

    if settings.FLAGS.job_name == "ps":
        server.join()
    elif settings.FLAGS.job_name == "worker":
        # Assign ops to the local worker by default; variables go to the ps tasks.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % settings.FLAGS.task_index,
                cluster=cluster)):
            isXception = trainer_functions.query_yes_no("Would you like to use the Xception model \n(if no, the model will default to that of the TensorFlow turorial)?")

            # Build the model; both branches read the same distorted inputs.
            images, labels = trainer_functions.distorted_inputs(isXception)
            if isXception:
                with slim.arg_scope(xception.xception_arg_scope()):
                    logits, _ = xception.xception(images, num_classes=10,
                                                  is_training=True)
            else:
                logits = trainer_functions.tutorial_model(images)

            # Calculate loss.
            loss = trainer_functions.loss(logits, labels)

            global_step = tf.contrib.framework.get_or_create_global_step()

            train_op = tf.train.AdagradOptimizer(0.01).minimize(
                loss, global_step=global_step)

        # The StopAtStepHook handles stopping after running given steps.
        hooks = [tf.train.StopAtStepHook(last_step=settings.FLAGS.max_steps)]

        # Log every `log_frequency` local steps so the throughput formula below
        # (log_frequency * batch_size / duration) matches the work actually
        # done. The old `global_step % 20` test used a different interval and,
        # with asynchronous workers sharing global_step, could skip multiples
        # of 20 entirely.
        log_every = max(1, settings.FLAGS.log_frequency)

        # The MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing when
        # done or when an error occurs.
        with tf.train.MonitoredTrainingSession(
                master=server.target,
                is_chief=(settings.FLAGS.task_index == 0),
                checkpoint_dir="./train_logs",
                hooks=hooks) as mon_sess:
            prev_time = time.time()
            local_step = 0
            while not mon_sess.should_stop():
                # Run a training step asynchronously (see
                # `tf.train.SyncReplicasOptimizer` for synchronous training).
                # mon_sess.run handles AbortedError in case of a preempted PS.
                # The loss is fetched in the same run so it belongs to the batch
                # just trained on. A separate run would dequeue and evaluate an
                # extra, untrained batch.
                _, loss_value = mon_sess.run([train_op, loss])
                local_step += 1
                if local_step % log_every == 0:
                    duration = time.time() - prev_time
                    prev_time = time.time()
                    examples_per_sec = log_every * settings.FLAGS.batch_size / duration
                    print("examples/sec: %d" % examples_per_sec + ", loss: %f" % loss_value)
def testFullyConvolutionalUnknownHeightWidth(self):
    """Without global pooling, unknown spatial dims stay unknown until run time.

    A 65x65 input fed at run time must produce a 3x3 feature map with 16
    channels.
    """
    batch = 2
    height, width = 65, 65
    global_pool = False
    inputs = create_test_input(batch, None, None, 3)
    with slim.arg_scope(xception.xception_arg_scope()):
        output, _ = self._xception_small(inputs, None, global_pool=global_pool)
    self.assertListEqual(output.get_shape().as_list(),
                         [batch, None, None, 16])
    images = create_test_input(batch, height, width, 3)
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        # Keep the tensor and its evaluated numpy value under separate names.
        output_value = sess.run(output, {inputs: images.eval()})
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual(output_value.shape, (batch, 3, 3, 16))
def testFullyConvolutionalEndpointShapes(self):
    """Check intermediate end-point shapes on a 321x321 input without global pooling."""
    inputs = create_test_input(2, 321, 321, 3)
    with slim.arg_scope(xception.xception_arg_scope()):
        _, end_points = self._xception_small(inputs, 10,
                                             global_pool=False,
                                             scope='xception')
    expected_shapes = {}
    expected_shapes['xception/entry_flow/conv1_1'] = [2, 161, 161, 32]
    expected_shapes['xception/entry_flow/block1'] = [2, 81, 81, 1]
    expected_shapes['xception/entry_flow/block2'] = [2, 41, 41, 2]
    expected_shapes['xception/entry_flow/block4'] = [2, 21, 21, 4]
    expected_shapes['xception/middle_flow/block1'] = [2, 21, 21, 4]
    expected_shapes['xception/exit_flow/block1'] = [2, 11, 11, 8]
    expected_shapes['xception/exit_flow/block2'] = [2, 11, 11, 16]
    for name in expected_shapes:
        actual = end_points[name].get_shape().as_list()
        self.assertListEqual(actual, expected_shapes[name])
def testClassificationShapes(self):
    """Check intermediate end-point shapes on a 224x224 input with global pooling."""
    inputs = create_test_input(2, 224, 224, 3)
    with slim.arg_scope(xception.xception_arg_scope()):
        _, end_points = self._xception_small(inputs, 10,
                                             global_pool=True,
                                             scope='xception')
    expected_shapes = {}
    expected_shapes['xception/entry_flow/conv1_1'] = [2, 112, 112, 32]
    expected_shapes['xception/entry_flow/block1'] = [2, 56, 56, 1]
    expected_shapes['xception/entry_flow/block2'] = [2, 28, 28, 2]
    expected_shapes['xception/entry_flow/block4'] = [2, 14, 14, 4]
    expected_shapes['xception/middle_flow/block1'] = [2, 14, 14, 4]
    expected_shapes['xception/exit_flow/block1'] = [2, 7, 7, 8]
    expected_shapes['xception/exit_flow/block2'] = [2, 7, 7, 16]
    for name in expected_shapes:
        actual = end_points[name].get_shape().as_list()
        self.assertListEqual(actual, expected_shapes[name])
def run():
    """Train Xception on the 'train' split and checkpoint it into ``log_dir``.

    Reads its configuration from module-level names defined elsewhere
    (log_dir, dataset_dir, file_pattern, batch_size, num_epochs,
    num_epochs_before_decay, initial_learning_rate,
    learning_rate_decay_factor) and uses the get_split / load_batch helpers.
    Training runs under a tf.train.Supervisor. The final model is saved to
    ``sv.save_path`` when training finishes.
    """
    # Create the log directory here; doing it at import time would create it
    # needlessly whenever the module is imported.
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    # ======================= TRAINING PROCESS =========================
    # Construct the graph and build the model.
    with tf.Graph().as_default() as graph:
        # Print log messages at INFO level and above.
        tf.logging.set_verbosity(tf.logging.INFO)

        # Create the dataset and load one batch.
        dataset = get_split('train', dataset_dir, file_pattern=file_pattern)
        images, _, labels = load_batch(dataset, batch_size=batch_size)

        # Number of steps before decaying the learning rate, and batches per epoch.
        num_batches_per_epoch = dataset.num_samples // batch_size
        num_steps_per_epoch = num_batches_per_epoch  # one step processes one batch
        decay_steps = int(num_epochs_before_decay * num_steps_per_epoch)

        # Build the model.
        with slim.arg_scope(xception_arg_scope()):
            logits, end_points = xception(images,
                                          num_classes=dataset.num_classes,
                                          is_training=True)

        # One-hot encode the labels for the softmax cross-entropy loss.
        one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)

        # Softmax cross-entropy with extra checks (equivalent to
        # tf.nn.sparse_softmax_cross_entropy_with_logits). get_total_loss()
        # also adds any registered regularization losses.
        loss = tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels,
                                               logits=logits)
        total_loss = tf.losses.get_total_loss()

        # Global step used by the learning-rate schedule and training.
        global_step = get_or_create_global_step()

        # Exponentially decaying (staircase) learning rate.
        lr = tf.train.exponential_decay(
            learning_rate=initial_learning_rate,
            global_step=global_step,
            decay_steps=decay_steps,
            decay_rate=learning_rate_decay_factor,
            staircase=True)

        # Optimizer driven by the decaying learning rate.
        # (Alternative tried: tf.train.RMSPropOptimizer(learning_rate=lr, momentum=0.9).)
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)

        # Create the train op; running it returns the total loss.
        train_op = slim.learning.create_train_op(total_loss, optimizer)

        # Metrics: class-index predictions (not one-hot) and streaming accuracy.
        predictions = tf.argmax(end_points['Predictions'], 1)
        probabilities = end_points['Predictions']
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(
            predictions, labels)
        metrics_op = tf.group(accuracy_update, probabilities)

        # Summaries to monitor, merged into one op.
        tf.summary.scalar('losses/Total_Loss', total_loss)
        tf.summary.scalar('accuracy', accuracy)
        tf.summary.scalar('learning_rate', lr)
        my_summary_op = tf.summary.merge_all()

        def train_step(sess, train_op, global_step):
            '''Run train_op, global_step and metrics_op in one session call.

            Logs the elapsed time for the step and returns
            (total_loss_value, global_step_count).
            '''
            start_time = time.time()
            loss_value, global_step_count, _ = sess.run(
                [train_op, global_step, metrics_op])
            time_elapsed = time.time() - start_time

            logging.info('global step %s: loss: %.4f (%.2f sec/step)',
                         global_step_count, loss_value, time_elapsed)

            return loss_value, global_step_count

        # Supervisor for a managed session. summary_op=None stops it from
        # running the summary op automatically, which would use too much memory.
        sv = tf.train.Supervisor(logdir=log_dir, summary_op=None)

        with sv.managed_session() as sess:
            for step in range(num_steps_per_epoch * num_epochs):
                # At the start of every epoch, show the vital information.
                if step % num_batches_per_epoch == 0:
                    # Integer division: '/' would print e.g. "1.0/5" on Python 3.
                    logging.info('Epoch %s/%s',
                                 step // num_batches_per_epoch + 1, num_epochs)
                    learning_rate_value, accuracy_value = sess.run([lr, accuracy])
                    logging.info('Current Learning Rate: %s', learning_rate_value)
                    logging.info('Current Streaming Accuracy: %s', accuracy_value)

                    # Sanity-check a few logits / predictions against labels.
                    logits_value, probabilities_value, predictions_value, labels_value = sess.run(
                        [logits, probabilities, predictions, labels])
                    print('logits: \n', logits_value[:5])
                    print('Probabilities: \n', probabilities_value[:5])
                    print('predictions: \n', predictions_value[:5])
                    print('Labels:\n:', labels_value[:5])

                # Write summaries every 10 steps; otherwise just train.
                if step % 10 == 0:
                    loss, _ = train_step(sess, train_op, sv.global_step)
                    summaries = sess.run(my_summary_op)
                    sv.summary_computed(sess, summaries)
                else:
                    loss, _ = train_step(sess, train_op, sv.global_step)

            # Log the final training loss and accuracy.
            logging.info('Final Loss: %s', loss)
            logging.info('Final Accuracy: %s', sess.run(accuracy))

            # Save the final checkpoint. Previously the message below was logged
            # but no save followed, so the model only persisted through the
            # Supervisor's periodic saves.
            logging.info('Finished training! Saving model to disk now.')
            sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
def model(inputs, region, is_training):
    """Build the Xception encoder + ConvLSTM + decoder network and return logits.

    The encoder is ``xception.xception``; the previous docstring's mention of
    ResNet was inaccurate. Reads ``data_format``, ``batch_norm_decay``,
    ``pre_trained_model`` and ``num_classes`` from the enclosing scope; they
    are not parameters.

    Args:
        inputs: Image batch tensor, presumably NHWC (see the NOTE below on
            channels_first).
        region: Extra tensor passed to the encoder alongside ``inputs``. Its
            meaning is not visible here.
        is_training: Python bool (it is used in a Python ``if``). Enables
            training-mode batch norm, checkpoint warm-start and LSTM input
            dropout.

    Returns:
        Logits from a final 1x1 convolution with ``num_classes`` output channels.
    """
    if data_format == 'channels_first':
        # Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
        # This provides a large performance boost on GPU. See
        # https://www.tensorflow.org/performance/performance_guide#data_formats
        # NOTE(review): the code below takes tf.shape(inputs)[1:3] as the
        # spatial size and concatenates on axis=3, both of which assume NHWC.
        # Confirm this branch works before enabling channels_first.
        inputs = tf.transpose(inputs, [0, 3, 1, 2])
        region = tf.transpose(region, [0, 3, 1, 2])

    # Encoder.
    with tf.contrib.slim.arg_scope(
            xception.xception_arg_scope(batch_norm_decay=batch_norm_decay)):
        _, end_points, low_level_features = xception.xception(
            inputs, region, num_classes=None, is_training=is_training)

    if is_training and pre_trained_model is not None:
        # Warm-start every variable created so far from the pre-trained
        # checkpoint, keyed by variable name without the ':0' suffix.
        variables_to_restore = tf.contrib.slim.get_variables_to_restore(
            exclude=None)
        tf.train.init_from_checkpoint(
            pre_trained_model,
            {v.name.split(':')[0]: v for v in variables_to_restore})

    inputs_size = tf.shape(inputs)[1:3]
    net = end_points['Logits']
    encoder_output = atrous_spatial_pyramid_pooling(
        net, batch_norm_decay, is_training)

    with tf.variable_scope("lstm"):
        with tf.contrib.slim.arg_scope(
                xception.xception_arg_scope(batch_norm_decay=batch_norm_decay)):
            with arg_scope([layers.batch_norm], is_training=is_training):
                # Static spatial width of the encoder output; the reshape
                # below assumes a square feature map.
                k_size = encoder_output.get_shape().as_list()[2]
                net = layers_lib.conv2d(encoder_output, 64, [1, 1],
                                        stride=1, scope="conv1_1x1")
                # Group every 4 consecutive batch entries into one time
                # sequence. The batch size presumably must be a multiple of
                # 4 — confirm.
                rnn_input = tf.reshape(net, [-1, 4, k_size, k_size, 64])
                cell_1 = tf.contrib.rnn.ConvLSTMCell(
                    conv_ndims=2, input_shape=[k_size, k_size, 64],
                    output_channels=64, kernel_shape=[3, 3])
                multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell([cell_1])
                # DropoutWrapper always applies dropout. Use keep_prob=1.0
                # (which adds no dropout) outside training so inference is
                # deterministic.
                dropout_rnn = tf.contrib.rnn.DropoutWrapper(
                    multi_rnn_cell,
                    input_keep_prob=0.4 if is_training else 1.0)
                rnn_outputs, _ = tf.nn.dynamic_rnn(cell=dropout_rnn,
                                                   inputs=rnn_input,
                                                   time_major=False,
                                                   dtype=tf.float32)

    with tf.variable_scope("decoder"):
        with tf.contrib.slim.arg_scope(
                xception.xception_arg_scope(batch_norm_decay=batch_norm_decay)):
            with arg_scope([layers.batch_norm], is_training=is_training):
                with tf.variable_scope("low_level_features"):
                    low_level_features = layers_lib.conv2d(
                        low_level_features, 24, [1, 1], stride=1,
                        scope='conv_1x1')
                    low_level_features_size = tf.shape(low_level_features)[1:3]

                with tf.variable_scope("upsampling_logits"):
                    # Undo the sequence grouping: back to one map per image.
                    encoder_output_re = tf.reshape(
                        rnn_outputs, [-1, k_size, k_size, 64])
                    net = layers_lib.conv2d(encoder_output_re, 64, [3, 3],
                                            stride=1, scope='conv_3x3_1')
                    net = layers_lib.conv2d(net, 64, [3, 3], stride=1,
                                            scope='conv_3x3_2')
                    net = tf.image.resize_bilinear(net, low_level_features_size,
                                                   name='upsample_1')
                    net = tf.concat([net, low_level_features], axis=3,
                                    name='concat')
                    net = layers_lib.conv2d(net, 44, [3, 3], stride=1,
                                            scope='conv_3x3_3')
                    net = layers_lib.conv2d(net, 44, [3, 3], stride=1,
                                            scope='conv_3x3_4')
                    net = tf.image.resize_bilinear(net, inputs_size,
                                                   name='upsample_2')
                    # 1x1 projection of the raw inputs, concatenated as an
                    # extra full-resolution feature.
                    low_level_features_two = layers_lib.conv2d(
                        inputs, 1, [1, 1], stride=1,
                        scope='low_level_feature_conv_1x1')
                    net = tf.concat([net, low_level_features_two], axis=3,
                                    name='concat_2')
                    logits = layers_lib.conv2d(net, num_classes, [1, 1],
                                               activation_fn=None,
                                               normalizer_fn=None,
                                               scope='outputs')

    return logits
def testEndpointNames(self):
    """The small Xception exposes exactly the expected set of end-point names.

    Instead of listing all 49 names literally, they are generated from the
    block layout. Every xception module contributes separable_conv1..3, an
    optional shortcut, and the module itself. Every block then contributes
    its own name.
    """
    inputs = create_test_input(2, 224, 224, 3)
    with slim.arg_scope(xception.xception_arg_scope()):
        _, end_points = self._xception_small(inputs, num_classes=10,
                                             global_pool=True,
                                             scope='xception')

    def module_names(module_prefix, has_shortcut):
        # End points emitted by one xception module, in emission order.
        names = ['%s/separable_conv%d' % (module_prefix, idx) for idx in (1, 2, 3)]
        if has_shortcut:
            names.append(module_prefix + '/shortcut')
        names.append(module_prefix)
        return names

    # (flow, block index, number of units, whether the units have a shortcut)
    layout = [
        ('entry_flow', 1, 1, True),
        ('entry_flow', 2, 1, True),
        ('entry_flow', 3, 1, True),
        ('entry_flow', 4, 1, True),
        ('middle_flow', 1, 2, False),
        ('exit_flow', 1, 1, True),
        ('exit_flow', 2, 1, False),
    ]

    expected = ['xception/entry_flow/conv1_1', 'xception/entry_flow/conv1_2']
    for flow, block_idx, unit_count, has_shortcut in layout:
        block_name = 'xception/%s/block%d' % (flow, block_idx)
        for unit_idx in range(1, unit_count + 1):
            expected += module_names(
                '%s/unit_%d/xception_module' % (block_name, unit_idx),
                has_shortcut)
        expected.append(block_name)
    expected += ['global_pool', 'xception/logits', 'predictions']

    self.assertItemsEqual(end_points.keys(), expected)
def run():
    """Evaluate a trained Xception checkpoint on the 'validation' split.

    Reads module-level configuration defined elsewhere (log_eval,
    dataset_dir, batch_size, num_epochs, checkpoint_file) and the
    get_split / load_batch helpers. Streaming accuracy and a text confusion
    matrix are written as summaries to ``log_eval``. Up to 10 images from
    the last batch are shown with matplotlib, then the session is
    checkpointed via the Supervisor.
    """
    # Create the directory for evaluation logs/summaries.
    if not os.path.exists(log_eval):
        os.mkdir(log_eval)

    # Construct the evaluation graph from scratch.
    with tf.Graph().as_default() as graph:
        tf.logging.set_verbosity(tf.logging.INFO)

        # Load one batch of validation images/labels. is_training=False
        # selects the evaluation preprocessing in load_batch.
        dataset = get_split('validation', dataset_dir)
        images, raw_images, labels = load_batch(dataset, batch_size=batch_size,
                                                is_training=False)

        # Integer count of full batches per epoch. With '/' this was a float,
        # so the `step % num_batches_per_epoch == 0` check below could only
        # fire at step 0 whenever num_samples was not a multiple of batch_size.
        num_batches_per_epoch = dataset.num_samples // batch_size

        # Build the inference model (is_training=False).
        with slim.arg_scope(xception_arg_scope()):
            _, end_points = xception(images, num_classes=dataset.num_classes,
                                     is_training=False)

        # Collect the variables to restore before the global step is created
        # below, so that the global step is not restored from the checkpoint.
        variables_to_restore = slim.get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        # Metrics only; no loss is needed for evaluation.
        probabilities = end_points['Predictions']
        predictions = tf.argmax(probabilities, 1)
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(
            predictions, labels)
        metrics_op = tf.group(accuracy_update)

        # No optimizer applies gradients here, so the global step is
        # incremented manually.
        global_step = get_or_create_global_step()
        global_step_op = tf.assign(global_step, global_step + 1)

        def eval_step(sess, metrics_op, global_step):
            '''Run the metrics update and the global-step increment once.

            Logs the streaming accuracy and step time. The ``global_step``
            argument is not used; the increment op from the enclosing scope is
            run instead. Returns the streaming accuracy value.
            '''
            start_time = time.time()
            _, global_step_count, accuracy_value = sess.run(
                [metrics_op, global_step_op, accuracy])
            time_elapsed = time.time() - start_time

            logging.info(
                'Global Step %s: Streaming Accuracy: %.4f (%.2f sec/step)',
                global_step_count, accuracy_value, time_elapsed)

            return accuracy_value

        # Scalar summaries to monitor.
        tf.summary.scalar('Validation_Accuracy', accuracy)

        # Confusion-matrix text summary. Use the dataset's class count, the
        # same value used to build the model, instead of a separate global.
        with tf.name_scope('confusion_matrix'):
            confusion_matrix = tf.as_string(
                tf.confusion_matrix(labels=labels, predictions=predictions,
                                    num_classes=dataset.num_classes,
                                    name='Confusion'))
            tf.summary.text("confusion_matrix", confusion_matrix)
        my_summary_op = tf.summary.merge_all()

        # The Supervisor restores the checkpoint through init_fn.
        sv = tf.train.Supervisor(logdir=log_eval, summary_op=None,
                                 init_fn=restore_fn)

        with sv.managed_session() as sess:
            for step in range(int(num_batches_per_epoch * num_epochs)):
                # Print vital information at the start of every epoch.
                if step % num_batches_per_epoch == 0:
                    logging.info('Epoch: %s/%s',
                                 step // num_batches_per_epoch + 1, num_epochs)
                    logging.info('Current Streaming Accuracy: %.4f',
                                 sess.run(accuracy))

                # Compute summaries every 10 steps; otherwise just evaluate.
                if step % 10 == 0:
                    eval_step(sess, metrics_op=metrics_op,
                              global_step=sv.global_step)
                    summaries = sess.run(my_summary_op)
                    sv.summary_computed(sess, summaries)
                else:
                    eval_step(sess, metrics_op=metrics_op,
                              global_step=sv.global_step)

            # Final streaming accuracy over all evaluated batches.
            logging.info('Final Streaming Accuracy: %.4f', sess.run(accuracy))

            # Visualize images from one more batch with the model's predictions.
            raw_images, labels, predictions, probabilities = sess.run(
                [raw_images, labels, predictions, probabilities])
            # Guard against batches smaller than 10 (previously an IndexError).
            for i in range(min(10, len(raw_images))):
                image, label, prediction, probability = (
                    raw_images[i], labels[i], predictions[i], probabilities[i])
                prediction_name = dataset.labels_to_name[prediction]
                label_name = dataset.labels_to_name[label]
                text = 'Prediction: %s \n Ground Truth: %s \n Probability: %s' % (
                    prediction_name, label_name, probability[prediction])
                img_plot = plt.imshow(image)

                # Title the plot and hide the axis ticks.
                plt.title(text)
                img_plot.axes.get_yaxis().set_ticks([])
                img_plot.axes.get_xaxis().set_ticks([])
                plt.show()

            logging.info(
                'Model evaluation has completed! Visit TensorBoard for more information regarding your evaluation.'
            )
            sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
# Initializer that points the shared dataset iterator at the final test set.
finaltest_init_op = iterator.make_initializer(final_test_dataset)
# Next (input, label/name) element from whichever dataset the iterator was
# last initialized with.
X,Y_name = iterator.get_next()
### end of tf.data API setup

######## Is_training
# Boolean fed at run time: whether the network runs in training mode.
Is_training = tf.placeholder(tf.bool)
### end

### dropout
# Keep probability for dropout, fed at run time.
KEEP_PROB = tf.placeholder(tf.float32)
### dropout end

### Network definition starts here.
# NOTE(review): this `xception` accepts a `keep_prob` keyword, so it is
# presumably a project-specific variant; confirm its signature.
with slim.arg_scope(xception_arg_scope()):
    Y_prediction,end_points = xception(X, num_classes=16, is_training=Is_training, scope='xception', keep_prob=KEEP_PROB)
### Network definition ends here.

# Class probabilities from the network's output (softmax over Y_prediction).
Y_softmax = tf.nn.softmax(Y_prediction)
# Initialization helper: per its original comment, it initializes the
# training and test sets, counts their examples, and maps each id to its
# class.
initialization()
# Create the TFRecord files.
# NOTE(review): this runs after the input pipeline above was built; confirm
# the ordering is intended.
preprocess.main()
variables_to_restore = slim.get_variables_to_restore()
### saver
# Saver over variables_to_restore; max_to_keep = 1 keeps only the most
# recent checkpoint (the old comment wrongly said 10).
saver = tf.train.Saver(variables_to_restore,max_to_keep = 1)