예제 #1
0
def rebuild_graph(sess, checkpoint_dir, input_image, batch_size, feature_dim):
    """Rebuild ENet plus the Instance transfer layer and restore all weights.

    A 12-class ENet is built on ``input_image``. A 2x2, stride-2
    ``conv2d_transpose`` with ``feature_dim`` output channels is then attached
    to the ``ENet/bottleneck5_1_last_prelu`` activation. Finally every variable
    reported by ``slim.get_variables_to_restore()`` is restored from the latest
    checkpoint in ``checkpoint_dir``.

    The Saver is created after the transfer layer exists, so that layer's
    variables are restored too. The checkpoint is presumably one saved from
    this same graph (TODO confirm against the caller).

    Args:
        sess: TF1 session the restore runs in.
        checkpoint_dir: directory searched with ``tf.train.latest_checkpoint``.
        input_image: image tensor fed to ENet.
        batch_size: batch size forwarded to ENet.
        feature_dim: number of output channels of the transfer layer.

    Returns:
        The output tensor of the transfer layer.
    """
    ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)

    # Fixed backbone configuration: one initial block, stage two repeated
    # twice, no skip connections.
    with slim.arg_scope(ENet_arg_scope()):
        _enet_logits, _enet_probs = ENet(input_image,
                                         num_classes=12,
                                         batch_size=batch_size,
                                         is_training=True,
                                         reuse=None,
                                         num_initial_blocks=1,
                                         stage_two_repeat=2,
                                         skip_connections=False)

    prelu_out = tf.get_default_graph().get_tensor_by_name(
        'ENet/bottleneck5_1_last_prelu:0')
    transfer_out = slim.conv2d_transpose(
        prelu_out, feature_dim, [2, 2], stride=2,
        scope='Instance/transfer_layer/conv2d_transpose')

    restorer = tf.train.Saver(slim.get_variables_to_restore())
    restorer.restore(sess, ckpt_path)

    return transfer_out
예제 #2
0
파일: MENet.py 프로젝트: YoungBim/MENet
    def MENet_Model(self):
        """Build the shared ENet encoder and one decoder head per task.

        For each entry of ``self.Tasks`` equal to ``'segmentation'`` or
        ``'depth'``, a decoder is built on the shared encoder. Its output is
        stored in ``self.predictions[task]`` as an identity op named
        ``<task>_pred``. Any other task name is skipped.

        For segmentation, the decoder's probabilities are also kept in
        ``self.probabilities``.
        """
        opt = self.opt
        with slim.arg_scope(ENet_arg_scope(weight_decay=opt.weight_decay)):
            # One encoder shared by every task-specific decoder.
            encoder = ENetEncoder(self.batch_images,
                                  batch_size=opt.batch_size,
                                  is_training=True,
                                  reuse=None,
                                  num_initial_blocks=opt.num_initial_blocks,
                                  stage_two_repeat=opt.stage_two_repeat,
                                  skip_connections=opt.skip_connections)

            # task name -> named prediction tensor (used later, e.g. for summaries)
            self.predictions = {}

            for task in self.Tasks:
                if task == 'segmentation':
                    seg_logits, seg_probs = ENetSegDecoder(
                        encoder,
                        opt.num_classes,
                        is_training=True,
                        reuse=None,
                        stage_two_repeat=opt.stage_two_repeat,
                        skip_connections=opt.skip_connections)
                    self.probabilities = seg_probs
                    head_output = seg_logits
                elif task == 'depth':
                    head_output = ENetDepthDecoder(
                        encoder,
                        skip_connections=opt.skip_connections,
                        is_training=True,
                        reuse=None)
                else:
                    continue
                # Wrap in identity so the prediction has a stable graph name.
                self.predictions[task] = tf.identity(head_output, name=task + '_pred')
예제 #3
0
def load_enet(sess, checkpoint_dir, input_image, batch_size, num_classes):
    """Restore a pretrained 12-class ENet and attach a fresh semantic head.

    The ENet variables are restored from the latest checkpoint in
    ``checkpoint_dir`` before the new head exists, so only they are looked up
    in the checkpoint.

    The new head is a 2x2, stride-2 ``conv2d_transpose`` (Xavier-initialised)
    with ``num_classes`` outputs, placed on
    ``ENet/bottleneck5_1_last_prelu``. It is followed by a softmax. Its
    weights and biases are then initialised in ``sess``.

    Args:
        sess: TF1 session used for restoring and initialising.
        checkpoint_dir: directory searched with ``tf.train.latest_checkpoint``.
        input_image: image tensor fed to ENet.
        batch_size: batch size forwarded to ENet.
        num_classes: number of output channels of the new head.

    Returns:
        ``(output, probabilities)``: the head's raw output and its softmax.
    """
    ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)

    # num_classes=12 here is the class count of the pretrained backbone,
    # independent of the head's ``num_classes``.
    with slim.arg_scope(ENet_arg_scope()):
        _enet_logits, _ = ENet(input_image,
                               num_classes=12,
                               batch_size=batch_size,
                               is_training=True,
                               reuse=None,
                               num_initial_blocks=1,
                               stage_two_repeat=2,
                               skip_connections=False)

    # Only the backbone's variables exist at this point.
    restorer = tf.train.Saver(slim.get_variables_to_restore())
    restorer.restore(sess, ckpt_path)

    prelu_out = tf.get_default_graph().get_tensor_by_name(
        'ENet/bottleneck5_1_last_prelu:0')
    output = slim.conv2d_transpose(
        prelu_out,
        num_classes, [2, 2],
        stride=2,
        weights_initializer=initializers.xavier_initializer(),
        scope='Semantic/transfer_layer/conv2d_transpose')
    probabilities = tf.nn.softmax(
        output, name='Semantic/transfer_layer/logits_to_softmax')

    # The head was created after the restore, so initialise just its two
    # variables and leave the restored backbone untouched.
    with tf.variable_scope('', reuse=True):
        head_weights = tf.get_variable(
            'Semantic/transfer_layer/conv2d_transpose/weights')
        head_biases = tf.get_variable(
            'Semantic/transfer_layer/conv2d_transpose/biases')
    sess.run([head_weights.initializer, head_biases.initializer])

    return output, probabilities
def load_enet(sess, checkpoint_dir, input_image, batch_size):
    """Build a 12-class ENet, restore it, and return its last PReLU output.

    Every variable reported by ``slim.get_variables_to_restore()`` is
    restored from the latest checkpoint in ``checkpoint_dir``.

    Args:
        sess: TF1 session the restore runs in.
        checkpoint_dir: directory searched with ``tf.train.latest_checkpoint``.
        input_image: image tensor fed to ENet.
        batch_size: batch size forwarded to ENet.

    Returns:
        The ``ENet/bottleneck5_1_last_prelu:0`` tensor, for attaching new heads.
    """
    ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)

    enet_config = {
        'num_classes': 12,
        'batch_size': batch_size,
        'is_training': True,
        'reuse': None,
        'num_initial_blocks': 1,
        'stage_two_repeat': 2,
        'skip_connections': False,
    }
    with slim.arg_scope(ENet_arg_scope()):
        _, _ = ENet(input_image, **enet_config)

    tf.train.Saver(slim.get_variables_to_restore()).restore(sess, ckpt_path)

    return tf.get_default_graph().get_tensor_by_name(
        'ENet/bottleneck5_1_last_prelu:0')
예제 #5
0
def run():
    """Run ENet inference over ``image_files`` and optionally save predictions.

    A non-shuffled input pipeline feeds a restored ENet. If ``save_images`` is
    set, one epoch of per-pixel argmax predictions is plotted and saved under
    ``photo_dir``, one PNG per input file.

    Relies on module-level names defined elsewhere in the file:
    ``image_files``, ``image_height``, ``image_width``, ``batch_size``,
    ``num_classes``, ``num_initial_blocks``, ``stage_two_repeat``,
    ``skip_connections``, ``checkpoint_file``, ``logdir``, ``save_images``,
    ``photo_dir`` and ``num_steps_per_epoch``.
    """
    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        #===================TEST BRANCH=======================
        # Load the files into one input queue. shuffle=False keeps the list
        # order, which is what lets predictions be mapped back to file names.
        images = tf.convert_to_tensor(image_files)
        input_queue = tf.train.slice_input_producer([images], shuffle=False)

        # Decode and preprocess the image (no annotation at test time).
        image = tf.read_file(input_queue[0])
        image = tf.image.decode_image(image, channels=3)
        preprocessed_image = preprocess(image, None, image_height, image_width)

        images = tf.train.batch([preprocessed_image],
                                batch_size=batch_size,
                                allow_smaller_final_batch=True)

        # Create the model inference.
        with slim.arg_scope(ENet_arg_scope()):
            _, probabilities = ENet(images,
                                    num_classes,
                                    batch_size=batch_size,
                                    is_training=True,
                                    reuse=None,
                                    num_initial_blocks=num_initial_blocks,
                                    stage_two_repeat=stage_two_repeat,
                                    skip_connections=skip_connections)

        # Set up the variables to restore and the restoring function.
        variables_to_restore = slim.get_variables_to_restore(exclude=[])
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        # Per-pixel class ids (argmax over the last axis, not one-hot encoded).
        predictions = tf.argmax(probabilities, -1)

        # Create the global step so the Supervisor has one to track.
        get_or_create_global_step()

        # Do not run a summary_op automatically; it would consume too much memory.
        sv = tf.train.Supervisor(logdir=logdir,
                                 summary_op=None,
                                 init_fn=restore_fn)

        # Run the managed session.
        with sv.managed_session() as sess:

            # Save the images.
            if save_images:
                if not os.path.exists(photo_dir):
                    os.mkdir(photo_dir)

                num_files = len(image_files)
                for step in range(int(num_steps_per_epoch)):
                    # Time a single inference batch.
                    time_run = time.time()
                    predictions_batch = sess.run(predictions)
                    time_run_end = time.time()

                    print('totally cost (second)', time_run_end - time_run)

                    for i in range(predictions_batch.shape[0]):
                        # Each step consumes `batch_size` files in list order,
                        # and the producer wraps around at the end of the list.
                        # So this is the file that produced prediction i.
                        file_index = (step * batch_size + i) % num_files
                        plt.imshow(predictions_batch[i])
                        # [15:] drops a fixed-length path prefix, presumably
                        # the input directory (TODO confirm).
                        plt.savefig(os.path.join(
                            photo_dir,
                            "image_" + str(image_files[file_index])[15:]))
                        # Reset the figure so images do not pile up on one axes.
                        plt.clf()
예제 #6
0
def run():
    """Train ENet and periodically evaluate it on a validation set.

    The function builds:

    - a shuffled training pipeline and a validation pipeline;
    - an ENet for each (the validation one shares weights via ``reuse=True``);
    - a weighted cross-entropy loss optimised with Adam and an exponentially
      decaying learning rate;
    - streaming accuracy and mean-IoU metrics for both branches.

    After training, the model is checkpointed through the Supervisor. If
    ``save_images`` is set, one validation batch of predictions and ground
    truth is plotted to ``photo_dir``.

    All configuration comes from module-level names defined elsewhere in the
    file, e.g. ``image_files``, ``annotation_files``, ``image_val_files``,
    ``annotation_val_files``, ``batch_size``, ``eval_batch_size``,
    ``num_classes``, ``num_epochs``, ``num_steps_per_epoch``,
    ``num_batches_per_epoch``, ``logdir``, ``save_images`` and ``photo_dir``.
    """
    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        #===================TRAINING BRANCH=======================
        #Load the files into one input queue
        images = tf.convert_to_tensor(image_files)
        annotations = tf.convert_to_tensor(annotation_files)
        input_queue = tf.train.slice_input_producer(
            [images,
             annotations])  #Slice_input producer shuffles the data by default.

        #Decode the image and annotation raw content
        image = tf.read_file(input_queue[0])
        image = tf.image.decode_image(image, channels=3)
        annotation = tf.read_file(input_queue[1])
        annotation = tf.image.decode_image(annotation)

        #preprocess and batch up the image and annotation
        preprocessed_image, preprocessed_annotation = preprocess(
            image, annotation, image_height, image_width)
        images, annotations = tf.train.batch(
            [preprocessed_image, preprocessed_annotation],
            batch_size=batch_size,
            allow_smaller_final_batch=True)

        #Create the model inference
        with slim.arg_scope(ENet_arg_scope(weight_decay=weight_decay)):
            logits, probabilities = ENet(images,
                                         num_classes,
                                         batch_size=batch_size,
                                         is_training=True,
                                         reuse=None,
                                         num_initial_blocks=num_initial_blocks,
                                         stage_two_repeat=stage_two_repeat,
                                         skip_connections=skip_connections)

        #perform one-hot-encoding on the ground truth annotation to get same shape as the logits
        annotations = tf.reshape(annotations,
                                 shape=[batch_size, image_height, image_width])
        annotations_ohe = tf.one_hot(annotations, num_classes, axis=-1)

        # Build the weighted loss. Its return value is not needed: the total
        # loss is gathered by tf.losses.get_total_loss() just below.
        # Presumably weighted_cross_entropy registers itself in the losses
        # collection (confirm).
        weighted_cross_entropy(logits=logits,
                               onehot_labels=annotations_ohe,
                               class_weights=class_weights)
        total_loss = tf.losses.get_total_loss()

        #Create the global step for monitoring the learning_rate and training.
        global_step = get_or_create_global_step()

        #Define your exponentially decaying learning rate
        lr = tf.train.exponential_decay(learning_rate=initial_learning_rate,
                                        global_step=global_step,
                                        decay_steps=decay_steps,
                                        decay_rate=learning_rate_decay_factor,
                                        staircase=True)

        #Now we can define the optimizer that takes on the learning rate
        optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=epsilon)

        # Create the train_op. Evaluating it applies one update and yields the
        # total loss value.
        train_op = slim.learning.create_train_op(total_loss, optimizer)

        #State the metrics that you want to predict. We get a predictions that is not one_hot_encoded.
        predictions = tf.argmax(probabilities, -1)
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(
            predictions, annotations)
        mean_IOU, mean_IOU_update = tf.contrib.metrics.streaming_mean_iou(
            predictions=predictions,
            labels=annotations,
            num_classes=num_classes)
        metrics_op = tf.group(accuracy_update, mean_IOU_update)

        def train_step(sess, train_op, global_step, metrics_op):
            '''
            Run one optimisation step together with the metric updates and log it.

            Returns (loss value, streaming accuracy, streaming mean IoU).
            '''
            start_time = time.time()
            loss_value, global_step_count, accuracy_value, mean_IOU_value, _ = sess.run(
                [train_op, global_step, accuracy, mean_IOU, metrics_op])
            time_elapsed = time.time() - start_time

            logging.info(
                'global step %s: loss: %.4f (%.2f sec/step)    Current Streaming Accuracy: %.4f    Current Mean IOU: %.4f',
                global_step_count, loss_value, time_elapsed, accuracy_value,
                mean_IOU_value)

            return loss_value, accuracy_value, mean_IOU_value

        #================VALIDATION BRANCH========================
        #Load the files into one input queue
        images_val = tf.convert_to_tensor(image_val_files)
        annotations_val = tf.convert_to_tensor(annotation_val_files)
        input_queue_val = tf.train.slice_input_producer(
            [images_val, annotations_val])

        #Decode the image and annotation raw content
        image_val = tf.read_file(input_queue_val[0])
        image_val = tf.image.decode_jpeg(image_val, channels=3)
        annotation_val = tf.read_file(input_queue_val[1])
        annotation_val = tf.image.decode_png(annotation_val)

        #preprocess and batch up the image and annotation
        preprocessed_image_val, preprocessed_annotation_val = preprocess(
            image_val, annotation_val, image_height, image_width)
        images_val, annotations_val = tf.train.batch(
            [preprocessed_image_val, preprocessed_annotation_val],
            batch_size=eval_batch_size,
            allow_smaller_final_batch=True)

        # Validation network: shares the training weights via reuse=True.
        with slim.arg_scope(ENet_arg_scope(weight_decay=weight_decay)):
            logits_val, probabilities_val = ENet(
                images_val,
                num_classes,
                batch_size=eval_batch_size,
                is_training=True,
                reuse=True,
                num_initial_blocks=num_initial_blocks,
                stage_two_repeat=stage_two_repeat,
                skip_connections=skip_connections)

        annotations_val = tf.reshape(
            annotations_val,
            shape=[eval_batch_size, image_height, image_width])

        #State the metrics that you want to predict. We get a predictions that is not one_hot_encoded.
        predictions_val = tf.argmax(probabilities_val, -1)
        accuracy_val, accuracy_val_update = tf.contrib.metrics.streaming_accuracy(
            predictions_val, annotations_val)
        mean_IOU_val, mean_IOU_val_update = tf.contrib.metrics.streaming_mean_iou(
            predictions=predictions_val,
            labels=annotations_val,
            num_classes=num_classes)
        metrics_op_val = tf.group(accuracy_val_update, mean_IOU_val_update)

        #Create an output for showing the segmentation output of validation images
        segmentation_output_val = tf.cast(predictions_val, dtype=tf.float32)
        segmentation_output_val = tf.reshape(
            segmentation_output_val, shape=[-1, image_height, image_width, 1])
        segmentation_ground_truth_val = tf.cast(annotations_val,
                                                dtype=tf.float32)
        segmentation_ground_truth_val = tf.reshape(
            segmentation_ground_truth_val,
            shape=[-1, image_height, image_width, 1])

        def eval_step(sess, metrics_op):
            '''
            Run the validation metric update once and log the streaming values.

            Returns (validation accuracy, validation mean IoU).
            '''
            start_time = time.time()
            _, accuracy_value, mean_IOU_value = sess.run(
                [metrics_op, accuracy_val, mean_IOU_val])
            time_elapsed = time.time() - start_time

            logging.info(
                '---VALIDATION--- Validation Accuracy: %.4f    Validation Mean IOU: %.4f    (%.2f sec/step)',
                accuracy_value, mean_IOU_value, time_elapsed)

            return accuracy_value, mean_IOU_value

        #=====================================================

        #Now finally create all the summaries you need to monitor and group them into one summary op.
        tf.summary.scalar('Monitor/Total_Loss', total_loss)
        tf.summary.scalar('Monitor/validation_accuracy', accuracy_val)
        tf.summary.scalar('Monitor/training_accuracy', accuracy)
        tf.summary.scalar('Monitor/validation_mean_IOU', mean_IOU_val)
        tf.summary.scalar('Monitor/training_mean_IOU', mean_IOU)
        tf.summary.scalar('Monitor/learning_rate', lr)
        tf.summary.image('Images/Validation_original_image',
                         images_val,
                         max_outputs=1)
        tf.summary.image('Images/Validation_segmentation_output',
                         segmentation_output_val,
                         max_outputs=1)
        tf.summary.image('Images/Validation_segmentation_ground_truth',
                         segmentation_ground_truth_val,
                         max_outputs=1)
        my_summary_op = tf.summary.merge_all()

        #Define your supervisor for running a managed session. Do not run the summary_op automatically or else it will consume too much memory
        sv = tf.train.Supervisor(logdir=logdir, summary_op=None, init_fn=None)

        # Number of validation batches that cover the validation set once.
        num_val_batches = len(image_val_files) // eval_batch_size
        # Validate every third of an epoch. Clamp to at least 1 so the modulo
        # below never divides by zero when an epoch has fewer than 3 steps.
        validation_interval = max(num_steps_per_epoch // 3, 1)

        # Pre-bind the values logged after training, so the final logging
        # works even if no step, or no validation pass, ran.
        loss = training_accuracy = training_mean_IOU = None
        validation_accuracy = validation_mean_IOU = None

        # Run the managed session
        with sv.managed_session() as sess:
            for step in range(int(num_steps_per_epoch * num_epochs)):
                #At the start of every epoch, show the vital information:
                if step % num_batches_per_epoch == 0:
                    logging.info('Epoch %s/%s',
                                 step // num_batches_per_epoch + 1, num_epochs)
                    learning_rate_value = sess.run(lr)
                    logging.info('Current Learning Rate: %s',
                                 learning_rate_value)

                # Every step trains once.
                loss, training_accuracy, training_mean_IOU = train_step(
                    sess, train_op, sv.global_step, metrics_op=metrics_op)

                #Log the summaries every 10 steps or every end of epoch, which ever lower.
                if step % min(num_steps_per_epoch, 10) == 0:
                    #Check the validation data only at every third of an epoch
                    if step % validation_interval == 0:
                        for _ in range(num_val_batches):
                            validation_accuracy, validation_mean_IOU = eval_step(
                                sess, metrics_op_val)

                    summaries = sess.run(my_summary_op)
                    sv.summary_computed(sess, summaries)

            #We log the final training loss
            logging.info('Final Loss: %s', loss)
            logging.info('Final Training Accuracy: %s', training_accuracy)
            logging.info('Final Training Mean IOU: %s', training_mean_IOU)
            logging.info('Final Validation Accuracy: %s', validation_accuracy)
            logging.info('Final Validation Mean IOU: %s', validation_mean_IOU)

            #Once all the training has been done, save the log files and checkpoint model
            logging.info('Finished training! Saving model to disk now.')
            sv.saver.save(sess, sv.save_path, global_step=sv.global_step)

            if save_images:
                if not os.path.exists(photo_dir):
                    os.mkdir(photo_dir)

                #Plot the predictions - check validation images only
                logging.info('Saving the images now...')
                predictions_value, annotations_value = sess.run(
                    [predictions_val, annotations_val])

                for i in range(len(predictions_value)):
                    plt.subplot(1, 2, 1)
                    plt.imshow(predictions_value[i])
                    plt.subplot(1, 2, 2)
                    plt.imshow(annotations_value[i])
                    plt.savefig(photo_dir + "/image_" + str(i))
                    # Reset the figure so images do not stack across iterations.
                    plt.clf()
예제 #7
0

# Build an inference graph over the PNG files in `images_list` and a
# `restore_fn` that loads ENet weights from `checkpoint`.
# Relies on names defined elsewhere: images_list, preprocess,
# num_initial_blocks, stage_two_repeat, skip_connections, checkpoint.
with tf.Graph().as_default() as graph:
    # File-name queue read in list order (shuffle=False).
    images_tensor = tf.train.string_input_producer(images_list, shuffle=False)
    reader = tf.WholeFileReader()
    # `key` identifies the record that was read (unused here);
    # `image_tensor` holds the file's raw bytes.
    key, image_tensor = reader.read(images_tensor)
    image = tf.image.decode_png(image_tensor, channels=3)
    # Alternative preprocessing, currently disabled:
    # image = tf.image.resize_image_with_crop_or_pad(image, 360, 480)
    # image = tf.cast(image, tf.float32)
    image = preprocess(image)
    # NOTE(review): batch size 10 is hard-coded here and must stay equal to
    # the batch_size passed to ENet below.
    images = tf.train.batch([image],
                            batch_size=10,
                            allow_smaller_final_batch=True)

    #Create the model inference
    # NOTE(review): num_classes=12 presumably matches the checkpoint being
    # restored; confirm.
    with slim.arg_scope(ENet_arg_scope()):
        logits, probabilities = ENet(images,
                                     num_classes=12,
                                     batch_size=10,
                                     is_training=True,
                                     reuse=None,
                                     num_initial_blocks=num_initial_blocks,
                                     stage_two_repeat=stage_two_repeat,
                                     skip_connections=skip_connections)

    # Restore whatever slim.get_variables_to_restore() returns with no
    # include/exclude filter.
    variables_to_restore = slim.get_variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    def restore_fn(sess):
        """Restore the selected variables into `sess` from `checkpoint`."""
        return saver.restore(sess, checkpoint)
예제 #8
0
def run():
    """Evaluate a restored ENet checkpoint on a labelled test set.

    Streams test images and annotations through ENet and accumulates
    streaming accuracy, mean IoU and mean per-class accuracy. These are logged
    and written as summaries every 10 steps. If ``save_images`` is set, up to
    10 prediction/ground-truth pairs from one batch are plotted to
    ``photo_dir``.

    All configuration comes from module-level names defined elsewhere in the
    file, e.g. ``image_files``, ``annotation_files``, ``batch_size``,
    ``num_classes``, ``checkpoint_file``, ``logdir``, ``num_epochs``,
    ``num_steps_per_epoch``, ``num_batches_per_epoch``, ``save_images`` and
    ``photo_dir``.
    """
    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        #===================TEST BRANCH=======================
        #Load the files into one input queue
        images = tf.convert_to_tensor(image_files)
        annotations = tf.convert_to_tensor(annotation_files)
        input_queue = tf.train.slice_input_producer([images, annotations])

        #Decode the image and annotation raw content
        image = tf.read_file(input_queue[0])
        image = tf.image.decode_image(image, channels=3)
        annotation = tf.read_file(input_queue[1])
        annotation = tf.image.decode_image(annotation)

        #preprocess and batch up the image and annotation
        preprocessed_image, preprocessed_annotation = preprocess(
            image, annotation, image_height, image_width)
        images, annotations = tf.train.batch(
            [preprocessed_image, preprocessed_annotation],
            batch_size=batch_size,
            allow_smaller_final_batch=True)

        #Create the model inference
        with slim.arg_scope(ENet_arg_scope()):
            logits, probabilities = ENet(images,
                                         num_classes,
                                         batch_size=batch_size,
                                         is_training=True,
                                         reuse=None,
                                         num_initial_blocks=num_initial_blocks,
                                         stage_two_repeat=stage_two_repeat,
                                         skip_connections=skip_connections)

        # Set up the variables to restore and restoring function from a saver.
        variables_to_restore = slim.get_variables_to_restore(exclude=[])

        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        # Drop the channel axis and cast labels to int64, the dtype of the
        # tf.argmax predictions they are compared with.
        annotations = tf.reshape(annotations,
                                 shape=[batch_size, image_height, image_width])
        annotations = tf.cast(annotations, tf.int64)

        #State the metrics that you want to predict. We get a predictions that is not one_hot_encoded.
        predictions = tf.argmax(probabilities, -1)
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(
            predictions, annotations)
        mean_IOU, mean_IOU_update = tf.contrib.metrics.streaming_mean_iou(
            predictions=predictions,
            labels=annotations,
            num_classes=num_classes)
        per_class_accuracy, per_class_accuracy_update = tf.metrics.mean_per_class_accuracy(
            labels=annotations,
            predictions=predictions,
            num_classes=num_classes)
        metrics_op = tf.group(accuracy_update, mean_IOU_update,
                              per_class_accuracy_update)

        #Create the global step and an increment op for monitoring
        global_step = get_or_create_global_step()
        global_step_op = tf.assign(
            global_step, global_step + 1
        )  #no apply_gradient method so manually increasing the global_step

        def eval_step(sess, metrics_op, global_step):
            '''
            Update the streaming metrics once, bump the global step and log.

            `global_step` is accepted for call-site symmetry but unused: the
            step count is read from `global_step_op`.
            Returns (accuracy, mean IoU, mean per-class accuracy).
            '''
            start_time = time.time()
            _, global_step_count, accuracy_value, mean_IOU_value, per_class_accuracy_value = sess.run(
                [
                    metrics_op, global_step_op, accuracy, mean_IOU,
                    per_class_accuracy
                ])
            time_elapsed = time.time() - start_time

            #Log some information
            logging.info(
                'Global Step %s: Streaming Accuracy: %.4f     Streaming Mean IOU: %.4f     Per-class Accuracy: %.4f (%.2f sec/step)',
                global_step_count, accuracy_value, mean_IOU_value,
                per_class_accuracy_value, time_elapsed)

            return accuracy_value, mean_IOU_value, per_class_accuracy_value

        #Create your summaries
        tf.summary.scalar('Monitor/test_accuracy', accuracy)
        tf.summary.scalar('Monitor/test_mean_per_class_accuracy',
                          per_class_accuracy)
        tf.summary.scalar('Monitor/test_mean_IOU', mean_IOU)
        my_summary_op = tf.summary.merge_all()

        #Define your supervisor for running a managed session. Do not run the summary_op automatically or else it will consume too much memory
        sv = tf.train.Supervisor(logdir=logdir,
                                 summary_op=None,
                                 init_fn=restore_fn)

        # Pre-bind the final metrics so the summary below is safe when no
        # evaluation step runs.
        test_accuracy = test_mean_IOU = test_per_class_accuracy = None

        #Run the managed session
        with sv.managed_session() as sess:
            for step in range(int(num_steps_per_epoch * num_epochs)):
                #print vital information every start of the epoch as always
                if step % num_batches_per_epoch == 0:
                    accuracy_value, mean_IOU_value = sess.run(
                        [accuracy, mean_IOU])
                    logging.info('Epoch: %s/%s',
                                 step // num_batches_per_epoch + 1, num_epochs)
                    logging.info('Current Streaming Accuracy: %.4f',
                                 accuracy_value)
                    logging.info('Current Streaming Mean IOU: %.4f',
                                 mean_IOU_value)

                test_accuracy, test_mean_IOU, test_per_class_accuracy = eval_step(
                    sess,
                    metrics_op=metrics_op,
                    global_step=sv.global_step)

                #Compute summaries every 10 steps
                if step % 10 == 0:
                    summaries = sess.run(my_summary_op)
                    sv.summary_computed(sess, summaries)

            #At the end of all the evaluation, show the final accuracy
            if test_accuracy is None:
                logging.info('No evaluation steps were run.')
            else:
                logging.info('Final Streaming Accuracy: %.4f', test_accuracy)
                logging.info('Final Mean IOU: %.4f', test_mean_IOU)
                logging.info('Final Per Class Accuracy %.4f',
                             test_per_class_accuracy)

            #Show end of evaluation
            logging.info('Finished evaluating!')

            #Save the images
            if save_images:
                if not os.path.exists(photo_dir):
                    os.mkdir(photo_dir)

                # Save visualisations for the first 10 images of one batch, or
                # fewer if the batch is smaller.
                logging.info('Saving the images now...')
                predictions_val, annotations_val = sess.run(
                    [predictions, annotations])

                for i in range(min(10, len(predictions_val))):
                    plt.subplot(1, 2, 1)
                    plt.imshow(predictions_val[i])
                    plt.subplot(1, 2, 2)
                    plt.imshow(annotations_val[i])
                    plt.savefig(photo_dir + "/image_" + str(i))
                    # Reset the figure so images do not stack across iterations.
                    plt.clf()
예제 #9
0
def run():
    """Run ENet inference over ``image_files`` and save side-by-side plots.

    Builds a queue-based input pipeline that reads, decodes and preprocesses
    every path in the module-level ``image_files`` list in order
    (``shuffle=False``), restores the ENet weights from ``checkpoint_file``
    through a ``tf.train.Supervisor`` and, for ``num_steps_per_epoch`` steps,
    feeds one preprocessed batch through the network via
    ``images_placeholder``.  When ``save_images`` is true, each prediction is
    plotted next to its source image and saved under ``photo_dir``.

    All configuration (``image_files``, ``image_height``, ``image_width``,
    ``batch_size``, ``num_classes``, ``num_initial_blocks``,
    ``stage_two_repeat``, ``skip_connections``, ``checkpoint_file``,
    ``logdir``, ``save_images``, ``photo_dir``, ``num_steps_per_epoch``) is
    read from module-level globals.
    """
    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        #===================TEST BRANCH=======================
        # Load the file paths into one input queue.  shuffle=False together
        # with tf.train.batch's default single thread keeps the dequeue order
        # equal to the order of ``image_files``.
        image_paths = tf.convert_to_tensor(image_files)
        input_queue = tf.train.slice_input_producer([image_paths],
                                                    shuffle=False)

        # Decode and preprocess the raw image content.
        image = tf.read_file(input_queue[0])
        image = tf.image.decode_image(image, channels=3)
        preprocessed_image = preprocess(image, None, image_height, image_width)

        image_batch = tf.train.batch([preprocessed_image],
                                     batch_size=batch_size,
                                     allow_smaller_final_batch=True)

        # The network reads from a placeholder; batches pulled from the queue
        # are fed into it explicitly.
        images_placeholder = tf.placeholder(
            tf.float32, [None, image_height, image_width, 3], name='rgb_image')
        with slim.arg_scope(ENet_arg_scope()):
            logits, probabilities = ENet(images_placeholder,
                                         num_classes,
                                         batch_size=batch_size,
                                         is_training=True,
                                         reuse=None,
                                         num_initial_blocks=num_initial_blocks,
                                         stage_two_repeat=stage_two_repeat,
                                         skip_connections=skip_connections)

        # Restore every model variable from the checkpoint.
        variables_to_restore = slim.get_variables_to_restore(exclude=[])
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        # summary_op=None: do not run summaries automatically (memory cost).
        sv = tf.train.Supervisor(logdir=logdir,
                                 summary_op=None,
                                 init_fn=restore_fn)

        with sv.managed_session() as sess:
            if not save_images:
                return
            if not os.path.exists(photo_dir):
                os.mkdir(photo_dir)

            num_files = len(image_files)
            for step in range(int(num_steps_per_epoch)):
                time_run = time.time()

                image_numpy = image_batch.eval(session=sess)
                probabilities_numpy = sess.run(
                    probabilities, feed_dict={images_placeholder: image_numpy})
                predictions_val = np.argmax(probabilities_numpy, -1)

                print('totally cost (second)', time.time() - time_run)

                # Iterate over the batch actually returned (it may be smaller
                # than batch_size) and pair element i with its own source file
                # so every element gets a distinct, correctly matched output.
                for i, predicted_image in enumerate(predictions_val):
                    # The producer cycles through the list, so wrap around.
                    image_path = image_files[(step * batch_size + i) % num_files]
                    plt.subplot(1, 2, 1)
                    plt.imshow(predicted_image)
                    plt.subplot(1, 2, 2)
                    plt.imshow(cv2.imread(image_path))
                    # NOTE(review): [15:] strips a fixed-length directory
                    # prefix and assumes a particular path layout; consider
                    # os.path.basename.
                    plt.savefig(photo_dir + "/image_" + str(image_path)[15:])
                    # Reset the figure so earlier images do not pile up on
                    # the reused axes.
                    plt.clf()
예제 #10
0
def run():
    """Evaluate a segmentation network and log streaming metrics and speed.

    Reads ``image_files`` / ``annotation_files`` through a queue, builds the
    network selected by the module-level ``network`` string ('ENet',
    'ENet_Small', 'ErfNet' or 'ErfNet_Small'), restores it from
    ``checkpoint_file`` and runs ``num_steps_per_epoch * num_epochs``
    evaluation steps.  Each step updates streaming accuracy, mean IOU and
    mean per-class accuracy.  At the end the final metric values, the elapsed
    time and the throughput (frames per second) are logged.

    Raises:
        ValueError: if ``network`` names an unsupported architecture.
    """
    supported_networks = ('ENet', 'ENet_Small', 'ErfNet', 'ErfNet_Small')
    if network not in supported_networks:
        # Without this check `probabilities` would be undefined further down.
        raise ValueError('Unsupported network %r; expected one of %r' %
                         (network, supported_networks))
    total_steps = int(num_steps_per_epoch * num_epochs)

    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        #===================TEST BRANCH=======================
        #Load the files into one input queue
        images = tf.convert_to_tensor(image_files)
        annotations = tf.convert_to_tensor(annotation_files)
        input_queue = tf.train.slice_input_producer([images, annotations])

        #Decode the image and annotation raw content
        image = tf.read_file(input_queue[0])
        image = tf.image.decode_image(image, channels=3)
        annotation = tf.read_file(input_queue[1])
        annotation = tf.image.decode_image(annotation)

        #preprocess and batch up the image and annotation
        preprocessed_image, preprocessed_annotation = preprocess(
            image, annotation, image_height, image_width)
        images, annotations = tf.train.batch(
            [preprocessed_image, preprocessed_annotation],
            batch_size=batch_size,
            allow_smaller_final_batch=True)

        #Create the model inference for the selected architecture
        with slim.arg_scope(ENet_arg_scope()):
            print('Building the network: ', network)
            if network == 'ENet':
                logits, probabilities = ENet(
                    images,
                    num_classes,
                    batch_size=batch_size,
                    is_training=is_training,
                    reuse=None,
                    num_initial_blocks=num_initial_blocks,
                    stage_two_repeat=stage_two_repeat,
                    skip_connections=skip_connections)
            elif network == 'ENet_Small':
                logits, probabilities = ENet_Small(
                    images,
                    num_classes,
                    batch_size=batch_size,
                    is_training=is_training,
                    reuse=None,
                    num_initial_blocks=num_initial_blocks,
                    skip_connections=skip_connections)
            elif network == 'ErfNet':
                logits, probabilities = ErfNet(images,
                                               num_classes,
                                               batch_size=batch_size,
                                               is_training=is_training,
                                               reuse=None)
            else:  # 'ErfNet_Small' (validated above)
                logits, probabilities = ErfNet_Small(images,
                                                     num_classes,
                                                     batch_size=batch_size,
                                                     is_training=is_training,
                                                     reuse=None)

        # Set up the variables to restore and restoring function from a saver.
        variables_to_restore = slim.get_variables_to_restore(exclude=[])
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        # Give the labels the same [batch, height, width] layout as the
        # arg-max predictions and the same integer dtype.
        annotations = tf.reshape(annotations,
                                 shape=[batch_size, image_height, image_width])
        annotations = tf.cast(annotations, tf.int64)

        #State the metrics that you want to predict. We get a predictions that is not one_hot_encoded.
        predictions = tf.argmax(probabilities, -1)
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(
            predictions, annotations)
        mean_IOU, mean_IOU_update = tf.contrib.metrics.streaming_mean_iou(
            predictions=predictions,
            labels=annotations,
            num_classes=num_classes)
        per_class_accuracy, per_class_accuracy_update = tf.metrics.mean_per_class_accuracy(
            labels=annotations,
            predictions=predictions,
            num_classes=num_classes)
        metrics_op = tf.group(accuracy_update, mean_IOU_update,
                              per_class_accuracy_update)

        #Create the global step and an increment op for monitoring
        global_step = get_or_create_global_step()
        global_step_op = tf.assign(
            global_step, global_step + 1
        )  #no apply_gradient method so manually increasing the global_step

        #Define your supervisor for running a managed session. Do not run the summary_op automatically or else it will consume too much memory
        sv = tf.train.Supervisor(logdir=logdir,
                                 summary_op=None,
                                 init_fn=restore_fn)

        #Run the managed session
        with sv.managed_session() as sess:
            start_time = time.time()
            for _ in range(total_steps):
                # Each run dequeues one batch and folds it into the metrics.
                sess.run([metrics_op, global_step_op])
            time_elapsed = time.time() - start_time

            # Read the metric values only after every update has run; values
            # fetched in the same sess.run as their update ops are not
            # guaranteed to include that step's update.
            test_accuracy, test_mean_IOU, test_per_class_accuracy = sess.run(
                [accuracy, mean_IOU, per_class_accuracy])

            #At the end of all the evaluation, show the final accuracy
            logging.info('Final Streaming Accuracy: %.4f', test_accuracy)
            logging.info('Final Mean IOU: %.4f', test_mean_IOU)
            logging.info('Final Per Class Accuracy %.4f',
                         test_per_class_accuracy)
            logging.info('Time Elapsed %.4f', time_elapsed)
            # Every step processes batch_size frames.
            logging.info('FPS %.4f',
                         total_steps * batch_size / float(time_elapsed))

            #Show end of evaluation
            logging.info('Finished evaluating!')
def run():
    """Evaluate ENet on annotated images and optionally save colour segmentations.

    Images and annotations from the module-level ``image_files`` /
    ``annotation_files`` lists are read in order (``shuffle=False``),
    preprocessed with ``preprocess_ori`` and batched together with their
    file names.  ENet is restored from ``checkpoint_file`` and evaluated for
    ``num_steps_per_epoch`` steps.  Each step updates streaming accuracy, mean
    IOU and mean per-class accuracy, and summaries are written every 10
    steps.  When ``save_images`` is true, each prediction is rendered with
    ``produce_color_segmentation`` and written to
    ``photo_dir/trainResult_<file name>``.
    """
    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        #===================TEST BRANCH=======================
        #Load the files into one input queue
        images = tf.convert_to_tensor(image_files)
        annotations = tf.convert_to_tensor(annotation_files)
        input_queue = tf.train.slice_input_producer([images, annotations],
                                                    shuffle=False)

        #Decode the image and annotation raw content
        filename = input_queue[0]
        image = tf.read_file(filename)
        image = tf.image.decode_image(image, channels=3)
        annotation = tf.read_file(input_queue[1])
        annotation = tf.image.decode_image(annotation)

        #preprocess and batch up the image, annotation and its file name
        preprocessed_image, preprocessed_annotation = preprocess_ori(
            image, annotation, image_height, image_width)
        images, annotations, filenames = tf.train.batch(
            [preprocessed_image, preprocessed_annotation, filename],
            batch_size=batch_size,
            allow_smaller_final_batch=True)

        #Create the model inference
        with slim.arg_scope(ENet_arg_scope()):
            logits, probabilities = ENet(images,
                                         num_classes,
                                         batch_size=batch_size,
                                         is_training=True,
                                         reuse=None,
                                         num_initial_blocks=num_initial_blocks,
                                         stage_two_repeat=stage_two_repeat,
                                         skip_connections=skip_connections)

        # Set up the variables to restore and restoring function from a saver.
        variables_to_restore = slim.get_variables_to_restore(exclude=[])
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        # Give the labels the same [batch, height, width] layout as the
        # arg-max predictions.
        annotations = tf.reshape(annotations,
                                 shape=[batch_size, image_height, image_width])

        #State the metrics that you want to predict. We get a predictions that is not one_hot_encoded.
        predictions = tf.argmax(probabilities, -1)
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(
            predictions, annotations)
        mean_IOU, mean_IOU_update = tf.contrib.metrics.streaming_mean_iou(
            predictions=predictions,
            labels=annotations,
            num_classes=num_classes)
        per_class_accuracy, per_class_accuracy_update = tf.metrics.mean_per_class_accuracy(
            labels=annotations,
            predictions=predictions,
            num_classes=num_classes)
        metrics_op = tf.group(accuracy_update, mean_IOU_update,
                              per_class_accuracy_update)

        #Create the global step and an increment op for monitoring
        global_step = get_or_create_global_step()
        global_step_op = tf.assign(
            global_step, global_step + 1
        )  #no apply_gradient method so manually increasing the global_step

        def eval_step(sess, metrics_op, global_step):
            '''
            Run one evaluation step on a single dequeued batch.

            Updates the streaming metrics, increments the global step, logs
            the current metric values and throughput and, if ``save_images``
            is set, writes a colour segmentation for every image of the batch.

            Returns:
                (accuracy, mean_IOU, per_class_accuracy, time_elapsed), where
                time_elapsed is the wall-clock duration of the sess.run call.
            '''
            # Fetch everything in ONE sess.run: every run that evaluates the
            # batch tensors dequeues a new batch, so a second run would save
            # predictions/file names of a batch the metrics never saw.
            start_time = time.time()
            (_, global_step_count, accuracy_value, mean_IOU_value,
             per_class_accuracy_value, predictions_val,
             filename_val) = sess.run([
                 metrics_op, global_step_op, accuracy, mean_IOU,
                 per_class_accuracy, predictions, filenames
             ])
            time_elapsed = time.time() - start_time
            num_frames = len(predictions_val)

            #Log some information
            logging.info(
                'Global Step %s: Streaming Accuracy: %.4f     Streaming Mean IOU: %.4f     Per-class Accuracy: %.4f    %.2f(sec/step)  %.2f (fps)',
                global_step_count, accuracy_value, mean_IOU_value,
                per_class_accuracy_value, time_elapsed / num_frames,
                num_frames / time_elapsed)

            #Save the colour segmentations
            if save_images:
                for prediction, raw_name in zip(predictions_val, filename_val):
                    segmentation = produce_color_segmentation(
                        prediction, image_height, image_width, dataset)
                    # String tensors are returned as bytes under Python 3.
                    if not isinstance(raw_name, str):
                        raw_name = raw_name.decode('utf-8')
                    out_name = (photo_dir + "/trainResult_" +
                                raw_name.split('/')[-1])
                    cv2.imwrite(out_name, segmentation)

            return accuracy_value, mean_IOU_value, per_class_accuracy_value, time_elapsed

        #Create your summaries
        tf.summary.scalar('Monitor/test_accuracy', accuracy)
        tf.summary.scalar('Monitor/test_mean_per_class_accuracy',
                          per_class_accuracy)
        tf.summary.scalar('Monitor/test_mean_IOU', mean_IOU)
        my_summary_op = tf.summary.merge_all()

        #Define your supervisor for running a managed session. Do not run the summary_op automatically or else it will consume too much memory
        sv = tf.train.Supervisor(logdir=logdir,
                                 summary_op=None,
                                 init_fn=restore_fn)

        #Run the managed session
        with sv.managed_session() as sess:
            if save_images and not os.path.exists(photo_dir):
                os.mkdir(photo_dir)

            # The first step is left out of the speed average (it typically
            # carries one-off warm-up cost); time and frame count are both
            # accumulated over the same steps.
            timed_seconds = 0.0
            timed_frames = 0
            for step in range(int(num_steps_per_epoch)):
                _, _, _, time_elapsed = eval_step(
                    sess,
                    metrics_op=metrics_op,
                    global_step=sv.global_step)

                #Compute summaries every 10 steps
                if step % 10 == 0:
                    summaries = sess.run(my_summary_op)
                    sv.summary_computed(sess, summaries)

                if step > 0:
                    timed_seconds += time_elapsed
                    timed_frames += batch_size

            # Read the metric values after the last update has completed.
            test_accuracy, test_mean_IOU, test_per_class_accuracy = sess.run(
                [accuracy, mean_IOU, per_class_accuracy])

            #At the end of all the evaluation, show the final accuracy
            logging.info('Final Streaming Accuracy: %.4f', test_accuracy)
            logging.info('Final Mean IOU: %.4f', test_mean_IOU)
            logging.info('Final Per Class Accuracy: %.4f',
                         test_per_class_accuracy)
            if timed_seconds > 0:
                logging.info('Average Speed: %.4f fps',
                             timed_frames / timed_seconds)

            #Show end of evaluation
            logging.info('Finished evaluating!')
예제 #12
0
def run():
    """Segment the image named by ``sys.argv[1]`` and save its label-map plot.

    The image is read through a queue, preprocessed and batched, then passed
    through ENet restored from ``checkpoint_file``.  The arg-max label map of
    the first batch element is drawn with matplotlib (no axes, no margins) and
    saved as ``./output/result<stem>``, where ``<stem>`` is the input's file
    name without directory or extension.  The output path is printed.
    """
    with tf.Graph().as_default():

        image_name = sys.argv[1]
        images = tf.convert_to_tensor([image_name])
        input_queue = tf.train.slice_input_producer([images])

        #Decode the image raw content
        image = tf.read_file(input_queue[0])
        image = tf.image.decode_image(image, channels=3)

        #Preprocess and batch the image
        image = preprocess(image, None, image_height, image_width)

        image = tf.train.batch([image],
                               batch_size=batch_size,
                               allow_smaller_final_batch=True)

        #Create the model inference
        with slim.arg_scope(ENet_arg_scope()):
            logits, probabilities = ENet(image,
                                         num_classes,
                                         batch_size=batch_size,
                                         is_training=True,
                                         reuse=None,
                                         num_initial_blocks=num_initial_blocks,
                                         stage_two_repeat=stage_two_repeat,
                                         skip_connections=skip_connections)

        # Set up the variables to restore and restoring function from a saver.
        variables_to_restore = slim.get_variables_to_restore(exclude=[])
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        #Predictions are class indices (not one-hot encoded).
        predictions = tf.argmax(probabilities, -1)

        #Define your supervisor for running a managed session. Do not run the summary_op automatically or else it will consume too much memory
        sv = tf.train.Supervisor(logdir=logdir,
                                 summary_op=None,
                                 init_fn=restore_fn)

        #Run the managed session
        with sv.managed_session() as sess:
            logging.info('Saving the images now...')
            predictions_val = sess.run(predictions)
            plt.imshow(predictions_val[0])
            plt.axis('off')
            plt.xticks([])
            plt.yticks([])
            plt.tight_layout()
            plt.subplots_adjust(left=0,
                                bottom=0,
                                right=1,
                                top=1,
                                hspace=0,
                                wspace=0)

            # Use only the file-name stem so inputs such as "./img.png" or
            # "dir/img.png" neither produce an empty name nor point into a
            # non-existent sub-directory of ./output.
            stem = os.path.splitext(os.path.basename(image_name))[0]
            out_dir = "./output"
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            out_file_name = out_dir + "/result" + stem
            # bbox_inches='tight' trims the surrounding whitespace.
            plt.savefig(out_file_name, bbox_inches='tight', pad_inches=0)
            print(out_file_name)
예제 #13
0
def run():
    """Repeatedly segment the frames found in ``<dataset_dir>/our_video``.

    Builds ENet on a fixed-size float placeholder and restores it from
    ``checkpoint_file``.  When ``save_images`` is true, it then loops forever.
    Each ``.jpg`` in ``<dataset_dir>/our_video`` is processed as follows:

    * resize it and convert it to RGB scaled to [0, 1];
    * segment it and clean the label map with a morphological opening;
    * colourise it with ``gray_convert_color`` and blend it onto the resized
      frame;
    * write the overlay to ``photo_dir`` under the frame's file name.

    The timing of each stage is printed.  When no frames are present, the
    loop waits ``poll_interval`` seconds before looking again.
    """
    poll_interval = 1.0  # seconds to wait when the frame directory is empty

    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        images_placeholder = tf.placeholder(tf.float32, [batch_size, image_height, image_width, 3], name='rgb_image')
        #Create the model inference
        with slim.arg_scope(ENet_arg_scope()):
            logits, probabilities = ENet(images_placeholder,
                                         num_classes,
                                         batch_size=batch_size,
                                         is_training=True,
                                         reuse=None,
                                         num_initial_blocks=num_initial_blocks,
                                         stage_two_repeat=stage_two_repeat,
                                         skip_connections=skip_connections)

        # Set up the variables to restore and restoring function from a saver.
        variables_to_restore = slim.get_variables_to_restore(exclude=[])
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        #Define your supervisor for running a managed session. Do not run the summary_op automatically or else it will consume too much memory
        sv = tf.train.Supervisor(logdir=logdir, summary_op=None, init_fn=restore_fn)

        with sv.managed_session() as sess:
            if not save_images:
                return
            if not os.path.exists(photo_dir):
                os.mkdir(photo_dir)

            video_dir = os.path.join(dataset_dir, 'our_video')
            kernel = np.ones((10, 10), np.uint8)
            while True:
                frame_paths = sorted(
                    os.path.join(video_dir, name)
                    for name in os.listdir(video_dir)
                    if name.endswith('.jpg'))

                if not frame_paths:
                    print('There is no new images to deal with')
                    # Avoid a busy loop while waiting for new frames.
                    time.sleep(poll_interval)
                    continue

                time_all_image_start = time.time()
                for image_name in frame_paths:
                    time_run = time.time()
                    image = cv2.imread(image_name)
                    if image is None:
                        # cv2.imread returns None when the file cannot be read.
                        print('Could not read image', image_name)
                        continue
                    image_resize = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
                    # OpenCV loads BGR; flip to RGB and scale to [0, 1].
                    image_rgb = image_resize.astype('float32')[:, :, ::-1] / 255.0

                    # The placeholder's batch dimension is fixed, so the frame
                    # goes into slot 0 and the remaining slots stay zero.
                    batch_x_float32 = np.zeros([batch_size, image_height, image_width, 3], dtype='float32')
                    batch_x_float32[0] = image_rgb
                    probabilities_numpy = sess.run(probabilities, feed_dict={images_placeholder: batch_x_float32})
                    predictions_val = np.argmax(probabilities_numpy, -1)

                    time_Preprocessing_and_Predict_end = time.time()
                    print('One image Preprocessing_and_Predict cost (second)', time_Preprocessing_and_Predict_end - time_run)

                    # Only slot 0 holds a real frame.  Post-processing the
                    # zero-padding slots too would write their predictions to
                    # the same output path and overwrite the real result.
                    predicted_image = predictions_val[0]

                    time_morphologyEx_start = time.time()
                    # Morphological opening removes small speckles of noise.
                    opened_labels = cv2.morphologyEx(predicted_image.astype('uint8'), cv2.MORPH_OPEN, kernel)
                    color_mask = np.ones(image_resize.shape, float)
                    print('morphologyEx cost (second)', time.time() - time_morphologyEx_start)

                    time_color_mask_start = time.time()
                    # Paint every pixel of each present label with its colour.
                    for label in np.unique(opened_labels):
                        rows, cols = np.where(opened_labels == label)
                        color_mask[rows, cols, :] = gray_convert_color(label)
                    color_mask_uint8 = (color_mask * 255).astype('uint8')
                    print('color_mask cost (second)', time.time() - time_color_mask_start)

                    time_addWeighted_start = time.time()
                    overlapping = cv2.addWeighted(image_resize, 0.2, color_mask_uint8, 0.8, 0)
                    print('addWeighted cost (second)', time.time() - time_addWeighted_start)

                    time_result_save_start = time.time()
                    cv2.imwrite(photo_dir + "/" + os.path.basename(image_name), overlapping)
                    print('result_save cost (second)', time.time() - time_result_save_start)

                    print('One image cost (second)', time.time() - time_run)

                time_all_image_end = time.time()
                logging.info('There are %d images in all', len(frame_paths))
                print('totally cost (second)', time_all_image_end - time_all_image_start)

                # TODO: processed frames are not removed, so the same frames
                # are segmented again on the next pass.
예제 #14
0
def run():
    """Segment every image in ``image_files`` with a restored ENet model.

    Reads configuration from names not defined in this function (presumably
    module-level settings elsewhere in the file): ``image_files``,
    ``image_height``, ``image_width``, ``batch_size``, ``num_classes``,
    ``num_initial_blocks``, ``stage_two_repeat``, ``skip_connections``,
    ``checkpoint_file``, ``photo_dir``, ``num_steps_per_epoch``,
    ``save_images`` and ``color``.

    For each of ``num_steps_per_epoch`` steps it runs one batch through the
    network and logs the timing. When ``save_images`` is true, it writes a
    colour segmentation for every image in the batch to
    ``photo_dir + "/trainResult_" + <basename of the source path>``.
    Finally it logs the average throughput over all steps.
    """
    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        # =================== TEST BRANCH =======================
        # Load the image paths into one input queue.
        image_paths = tf.convert_to_tensor(image_files)
        input_queue = tf.train.slice_input_producer([image_paths])

        # Decode the raw image content. Keep the path tensor so each
        # result can be named after its source file.
        filename = input_queue[0]
        image = tf.read_file(filename)
        image = tf.image.decode_image(image, channels=3)

        # Preprocess and batch up the images together with their paths.
        preprocessed_image = preprocess_ori(image, None, image_height,
                                            image_width)
        images, filenames = tf.train.batch([preprocessed_image, filename],
                                           batch_size=batch_size,
                                           allow_smaller_final_batch=True)

        # Create the model inference.
        # NOTE(review): is_training=True at inference keeps training-mode
        # layers (e.g. batch norm) active; left unchanged because the
        # restored checkpoint may rely on it -- confirm before changing.
        with slim.arg_scope(ENet_arg_scope()):
            _, probabilities = ENet(images,
                                    num_classes,
                                    batch_size=batch_size,
                                    is_training=True,
                                    reuse=None,
                                    num_initial_blocks=num_initial_blocks,
                                    stage_two_repeat=stage_two_repeat,
                                    skip_connections=skip_connections)

        # Restore every model variable (nothing excluded) from the checkpoint.
        variables_to_restore = slim.get_variables_to_restore(exclude=[])
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        # Per-pixel class indices (argmax over the last axis, not one-hot).
        predictions = tf.argmax(probabilities, -1)

        # Managed session. The summary_op is not run automatically because
        # it would consume too much memory.
        sv = tf.train.Supervisor(logdir=photo_dir,
                                 summary_op=None,
                                 init_fn=restore_fn)

        with sv.managed_session() as sess:

            # Output directory for the saved segmentations. makedirs also
            # creates any missing parent directories.
            if not os.path.exists(photo_dir):
                os.makedirs(photo_dir)

            total_time = 0.0
            num_steps = int(num_steps_per_epoch)
            logging.info('Total Steps: %d', num_steps)
            for step in range(num_steps):
                start_time = time.time()
                predictions_val, filename_val = sess.run(
                    [predictions, filenames])
                time_elapsed = time.time() - start_time
                # NOTE: the first value is seconds per image
                # (time / batch_size), despite the 'sec/step' label.
                logging.info('step %d  %.2f(sec/step)  %.2f (fps)', step,
                             time_elapsed / batch_size,
                             batch_size / time_elapsed)
                total_time += time_elapsed

                if save_images:
                    # Iterate over what was actually returned. The final
                    # batch may be smaller than batch_size because of
                    # allow_smaller_final_batch=True.
                    for prediction, raw_name in zip(predictions_val,
                                                    filename_val):
                        segmentation = produce_color_segmentation(
                            prediction, image_height, image_width, color)
                        # Python 3 returns string tensors as bytes; decode
                        # them so str operations below work.
                        if not isinstance(raw_name, str):
                            raw_name = raw_name.decode('utf-8')
                        base_name = raw_name.split('/')[-1]
                        out_path = photo_dir + "/trainResult_" + base_name
                        cv2.imwrite(out_path, segmentation)

            # Guard against division by zero when no steps were run.
            if total_time > 0:
                logging.info('Average speed: %.2f fps',
                             len(image_files) / total_time)