예제 #1
0
def main():
    """Extract CNN features for the data set and classify them with a linear SVM.

    Loads the train/test split at ``resize_shape`` x ``resize_shape``, replaces
    both splits with the output of ``CNN.inference`` (the CNN is used purely as
    a feature extractor), fits ``svm.LinearSVC`` on the training features and
    prints the confusion matrix, per-class F1/precision/recall and overall
    accuracy on the test split.
    """
    resize_shape = 64
    print("data is loading...")
    train_X, train_Y, test_X, test_Y = load_data(resize_shape)
    print("data is loaded")
    print("feature engineering...")
    learning_rate = 0.01
    training_iters = 100000
    batch_size = 128
    display_step = 10

    # Network Parameters
    n_input = resize_shape * resize_shape  # input values per resized image
    n_classes = 62  # number of output classes
    dropout = 0.5  # Dropout, probability to keep units

    with tf.Session() as sess:
        cnn = CNN(sess, learning_rate, training_iters, batch_size, display_step,
                  n_input, n_classes, dropout, resize_shape)
        # The CNN outputs replace the raw inputs and become the SVM features.
        train_X = cnn.inference(train_X)
        test_X = cnn.inference(test_X)

    print("feature engineering is complete")

    print('training phase')
    clf = svm.LinearSVC().fit(train_X, train_Y)
    print('test phase')
    predicts = clf.predict(test_X)

    # measure function
    print('measure phase')
    print(confusion_matrix(test_Y, predicts))
    # average=None -> one score per class.
    print(f1_score(test_Y, predicts, average=None))
    print(precision_score(test_Y, predicts, average=None))
    print(recall_score(test_Y, predicts, average=None))
    print(accuracy_score(test_Y, predicts))
예제 #2
0
def main():
  """Restore the CNN checkpoint and print its accuracy on the test CSV.

  Builds the input pipeline from ``data/test/data.csv`` (not shuffled),
  restores the weights saved at ``LOG_DIR/model.ckpt`` and evaluates the
  ``accuracy`` tensor once.
  """
  with tf.Graph().as_default():
    cnn = CNN(image_size=FLAGS.image_size, class_count=len(Channel))
    images, labels = load_data(
      'data/test/data.csv',
      batch_size=FLAGS.batch_size,
      image_size=FLAGS.image_size,
      class_count=len(Channel),
      shuffle=False)
    keep_prob = tf.placeholder(tf.float32)

    logits = cnn.inference(images, keep_prob)
    accuracy = cnn.accuracy(logits, labels)

    saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
      sess.run(init_op)
      saver.restore(sess, os.path.join(LOG_DIR, 'model.ckpt'))
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)

      try:
        # Evaluate with keep_prob = 1.0 (as the single-image inference path
        # does): feeding a training value such as 0.5 would leave dropout
        # active and make the measured accuracy noisy and biased low.
        accuracy_value = sess.run(accuracy, feed_dict={keep_prob: 1.0})

        print(f'test accuracy: {accuracy_value}')
      finally:
        # Always shut the queue-runner threads down, even if evaluation fails.
        coord.request_stop()
        coord.join(threads)
예제 #3
0
def main(imagepath):
    """Classify one image with the restored CNN and print the scores.

    Args:
        imagepath: path of the image file to classify.
    """
    cnn = CNN(image_size=FLAGS.image_size, class_count=len(Channel))
    image = load_image(imagepath, image_size=FLAGS.image_size)
    keep_prob = tf.placeholder(tf.float32)
    logits = cnn.inference(image, keep_prob, softmax=True)

    saver = tf.train.Saver()
    # A context-managed session is closed (and its resources released) once the
    # prediction is computed; the previous InteractiveSession was never closed.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, os.path.join(LOG_DIR, 'model.ckpt'))
        # keep_prob 1.0 at inference; flatten to a 1-D vector of scores.
        softmax = sess.run(logits, feed_dict={keep_prob: 1.0}).flatten()
    print_results(imagepath, softmax)
예제 #4
0
class Prob(object):
    """Score the saved face crops of an image with a restored CNN."""

    def __init__(self):
        self.c = CNN()
        self.images_placeholder = tf.placeholder("float")
        self.keep_prob = tf.placeholder("float")
        # Network output; get_prob() reads element [0][0] of it for each face.
        self.softmax = self.c.inference(self.images_placeholder,
                                        self.keep_prob)
        self.saver = tf.train.Saver()

    def _restore_latest(self, sess):
        """Restore the newest checkpoint in the working directory, if any."""
        ckpt = tf.train.get_checkpoint_state('./')
        if ckpt:  # a checkpoint exists
            # Load variable values from the most recently saved model.
            self.saver.restore(sess, ckpt.model_checkpoint_path)

    def _face_prob(self, sess, save_path, image_name):
        """Run the network on one face crop and return output element [0][0]."""
        CVimage = cv2.imread(os.path.join(save_path, image_name))
        image = self.c.shape_CVimage(CVimage)
        return sess.run(self.softmax,
                        feed_dict={
                            self.images_placeholder: [image],
                            self.keep_prob: 1.0
                        })[0][0]

    def get_prob(self, image_id):
        """Return the network's probability for every saved face of an image.

        Args:
            image_id: identifier passed to ``get_save_path`` to locate the
                directory holding the ``face_*`` files.

        Returns:
            ``{"status": "success", "data_type": "detail",
            "detail": {"probability": {image_num: float}}}`` where
            ``image_num`` is the digit run found between underscores in the
            file name, or ``{"status": "error", "message": ...}`` as soon as
            one face yields a falsy value.
        """
        save_path = get_save_path(image_id)
        face_list = [img for img in os.listdir(save_path) if "face_" in img]
        p = {}
        if face_list:
            # Restore the model once and reuse one session for every face,
            # instead of opening a session and re-reading the checkpoint per
            # image.
            with tf.Session() as sess:
                self._restore_latest(sess)
                for image_name in face_list:
                    # Digits between underscores, e.g. "face_3_a.jpg" -> "3".
                    image_num = re.search(r"_(\d+)_", image_name).group(1)
                    prob = self._face_prob(sess, save_path, image_name)
                    # NOTE(review): any falsy value (including exactly 0.0)
                    # aborts the whole request, as the original check did.
                    if not prob:
                        return {
                            "status": "error",
                            "message": "can not get valid probability",
                        }
                    p[image_num] = float(prob)
        return {
            "status": "success",
            "data_type": "detail",
            "detail": {
                "probability": p
            }
        }
예제 #5
0
def export():
    """Export the latest checkpoint as a frozen and an inference-optimized graph.

    Rebuilds the inference graph around a ``float32`` placeholder named
    ``image`` (shape ``(None, None, 3)``) with ``keep_prob`` fixed to 1.0 and a
    batch size of 1. It then restores the moving-average weights from
    ``FLAGS.ckpt`` and writes a checkpoint plus ``graph.pbtxt`` to
    ``FLAGS.ckpt/export``. That pair is frozen into ``frozen.pb`` and passed
    through ``optimize_for_inference`` into ``optimized.pb`` in the same
    directory.

    Raises:
        ValueError: if the written graph or checkpoint cannot be found.
    """
    output_dir = os.path.join(FLAGS.ckpt, "export")
    # Always start from an empty export directory.
    if tf.gfile.Exists(output_dir):
        tf.gfile.DeleteRecursively(output_dir)
    tf.gfile.MakeDirs(output_dir)

    ### EDIT TRAINING GRAPH ###
    with tf.Graph().as_default():
        nn = CNN(FLAGS.network, FLAGS.num_of_classes, 1, FLAGS.image_size,
                 FLAGS.image_crop_size, FLAGS.log_input, FLAGS.grayscale,
                 FLAGS.log_feature, FLAGS.use_fp16)

        with tf.device("/cpu:0"):
            image = tf.placeholder(tf.float32,
                                   shape=(None, None, 3),
                                   name="image")
            image = nn.input(image, True, False)

        dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
        keep_prob = tf.constant(1.0, dtype=dtype)
        batch_size = tf.constant(1)
        logits = nn.inference(image, keep_prob, batch_size)

        # Calculate predictions.
        logits = tf.cast(logits, tf.float32)
        softmax = tf.nn.softmax(logits)

        # Restore the moving average version of the learned variables.
        variable_averages = tf.train.ExponentialMovingAverage(
            CNN.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(FLAGS.ckpt)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Assuming the path ends in "model.ckpt-<step>", the text after
            # the last '-' is the global step.
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                '-')[-1]

            # Re-save the restored variable values next to the graph
            # definition so freeze_graph below can combine the two.
            tf.train.Saver().save(
                sess,
                os.path.join(output_dir, 'model.ckpt'),
                global_step=tf.convert_to_tensor(global_step))
            tf.train.write_graph(sess.graph.as_graph_def(),
                                 output_dir,
                                 'graph.pbtxt',
                                 as_text=True)

    ### EXPORT MODEL ###
    graph_path = os.path.join(output_dir, 'graph.pbtxt')
    if not tf.gfile.Exists(graph_path):
        raise ValueError('Graph not found({})'.format(graph_path))

    # get_checkpoint_state() returns None when no checkpoint exists, so check
    # it before dereferencing model_checkpoint_path.
    ckpt = tf.train.get_checkpoint_state(output_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        raise ValueError('Check point not found.')
    ckpt_path = ckpt.model_checkpoint_path

    output_path = os.path.join(output_dir, 'frozen.pb')
    optimized_output_path = os.path.join(output_dir, 'optimized.pb')

    freeze_graph.freeze_graph(input_graph=graph_path,
                              input_saver="",
                              input_binary=False,
                              input_checkpoint=ckpt_path,
                              output_node_names="softmax_linear/softmax",
                              restore_op_name="save/restore_all",
                              filename_tensor_name="save/Const:0",
                              output_graph=output_path,
                              clear_devices=True,
                              initializer_nodes="")

    # frozen.pb is a serialized (binary) GraphDef: read it as bytes.
    input_graph_def = tf.GraphDef()
    with tf.gfile.Open(output_path, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def, ['image'], ["softmax_linear/softmax"],
        tf.float32.as_datatype_enum)

    # Write in binary mode and close the file before its size is measured,
    # so the reported size reflects the fully flushed file.
    with tf.gfile.GFile(optimized_output_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    output_size = os.path.getsize(output_path)
    optimized_output_size = os.path.getsize(optimized_output_path)

    print('Model Exported successfuly.')
    print('- Frozen Model: {} ({})'.format(output_path,
                                           _humansize(output_size)))
    print('- Optimized Model: {} ({})'.format(
        optimized_output_path, _humansize(optimized_output_size)))
예제 #6
0
def evaluate(data_dir):
    """Evaluate the latest checkpoint on the labelled images in ``data_dir``.

    The class count comes from ``_num_of_folders(data_dir)`` and the sample
    count from ``_num_of_files(data_dir)``. The function rebuilds the
    inference graph with ``keep_prob`` fixed to 1.0 and restores the
    moving-average weights from ``FLAGS.ckpt``. It then runs top-1
    classification over ``ceil(num_of_samples / FLAGS.batch_size)`` batches.
    Misclassified file names are printed, and copied to ``FLAGS.error_dir``
    when it is set. Precision @ 1 is printed and written to
    ``FLAGS.event_dir``, together with a text summary of the run settings.

    Args:
        data_dir: directory holding the evaluation images.

    Raises:
        ValueError: on a missing ``data_dir`` or checkpoint, an invalid
            ``data_dir`` or an unknown number of samples.
    """
    # PreProcess
    if not data_dir:
        raise ValueError('Please supply a data_dir')

    if not tf.gfile.Exists(FLAGS.ckpt):
        raise ValueError('Please supply a checkpoint')

    if FLAGS.event_dir is None:
        FLAGS.event_dir = os.path.join(FLAGS.ckpt, 'event')

    # Start from an empty event directory so earlier runs do not mix in.
    if tf.gfile.Exists(FLAGS.event_dir):
        tf.gfile.DeleteRecursively(FLAGS.event_dir)
    tf.gfile.MakeDirs(FLAGS.event_dir)

    num_of_classes = _num_of_folders(data_dir)
    if num_of_classes <= 0:
        raise ValueError('Invalid data_dir')

    num_of_samples = _num_of_files(data_dir)
    if num_of_samples is None:
        raise ValueError('Please supply num_of_samples.')

    print('[ SUMMARY ]')
    print('Num of classes: {}'.format(num_of_classes))
    print('Num of samples: {}'.format(num_of_samples))

    with tf.Graph().as_default() as g:
        info = [['Number of classes', str(num_of_classes)],
                ['Number of samples', str(num_of_samples)],
                ['Image size', str(FLAGS.image_size)],
                ['Image Crop size',
                 str(FLAGS.image_crop_size)],
                ['Grayscale', str(FLAGS.grayscale)],
                ['Use Float16', str(FLAGS.use_fp16)]]
        # collections=[] keeps this out of merge_all(); it is run once below.
        info_summary = tf.summary.text('NetworkInfo',
                                       tf.convert_to_tensor(info),
                                       collections=[])

        nn = CNN(FLAGS.network, num_of_classes, num_of_samples,
                 FLAGS.image_size, FLAGS.image_crop_size, FLAGS.log_input,
                 FLAGS.grayscale, FLAGS.log_feature, FLAGS.use_fp16)

        with tf.device("/cpu:0"):
            # Get images, labels and their file names.
            images, labels, filenames = nn.inputs(data_dir, FLAGS.batch_size,
                                                  FLAGS.batch_size * 100,
                                                  FLAGS.scaling, False)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
        keep_prob = tf.constant(1.0, dtype=dtype)
        batch_size = tf.constant(FLAGS.batch_size)
        logits = nn.inference(images, keep_prob, batch_size)

        # Calculate predictions.
        logits = tf.cast(logits, tf.float32)
        top_k_op = tf.nn.in_top_k(logits, labels, 1)

        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            CNN.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # merge_all() returns None when the graph defines no summaries.
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.event_dir, g)

        try:
            with tf.Session() as sess:
                summary_writer.add_summary(sess.run(info_summary))

                ckpt = tf.train.get_checkpoint_state(FLAGS.ckpt)
                if not (ckpt and ckpt.model_checkpoint_path):
                    print('No checkpoint file found')
                    return
                # Restores from checkpoint
                saver.restore(sess, ckpt.model_checkpoint_path)
                # Assuming model_checkpoint_path looks something like
                # /my-favorite-path/cifar10_train/model.ckpt-0, the text after
                # the last '-' is the global step.
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                    '-')[-1]

                # Start the queue runners.
                coord = tf.train.Coordinator()
                threads = []
                try:
                    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                        threads.extend(
                            qr.create_threads(sess,
                                              coord=coord,
                                              daemon=True,
                                              start=True))

                    num_iter = int(math.ceil(num_of_samples / FLAGS.batch_size))
                    true_count = 0  # Counts the number of correct predictions.
                    total_sample_count = num_iter * FLAGS.batch_size
                    step = 0

                    errors = []
                    while step < num_iter and not coord.should_stop():
                        predictions, labels_value, filenames_value = sess.run(
                            [top_k_op, labels, filenames])
                        true_count += np.sum(predictions)
                        step += 1

                        print(labels_value)
                        errors.extend(
                            name
                            for name, correct in zip(filenames_value, predictions)
                            if not correct)

                    # Print errors. File names fetched from a string tensor
                    # come back as bytes under Python 3; convert before join.
                    print('Errors:')
                    print('\n'.join(tf.compat.as_str(f) for f in errors))

                    # Copy Errors
                    if FLAGS.error_dir is not None:
                        _copy_errors(errors, FLAGS.error_dir)

                    # Compute precision @ 1.
                    precision = true_count / total_sample_count
                    print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))

                    summary = tf.Summary()
                    if summary_op is not None:
                        summary.ParseFromString(sess.run(summary_op))
                    summary.value.add(tag='Precision @ 1', simple_value=precision)
                    summary_writer.add_summary(summary, global_step)
                except Exception as e:  # pylint: disable=broad-except
                    # Hand the error to the coordinator; join() re-raises it.
                    coord.request_stop(e)

                coord.request_stop()
                coord.join(threads, stop_grace_period_secs=10)
        finally:
            # FileWriter buffers events; close it so they reach the disk.
            summary_writer.close()
예제 #7
0
def evaluate_single():
    """Classify ``FLAGS.image`` with the latest checkpoint and print the scores.

    Decodes the file as a 3-channel JPEG, rebuilds the inference graph with
    ``keep_prob`` fixed to 1.0 and a batch size of 1, and restores the
    moving-average weights from ``FLAGS.ckpt``. It prints the softmax value
    of every class. A text summary of the settings and the graph's merged
    summaries are written to ``FLAGS.event_dir``.

    Raises:
        ValueError: if the image, the checkpoint or ``FLAGS.num_of_classes``
            is missing.
    """
    # PreProcess
    if not FLAGS.image:
        raise ValueError('Please supply a image')
    if not tf.gfile.Exists(FLAGS.image):
        raise ValueError('Image not found.')

    if not tf.gfile.Exists(FLAGS.ckpt):
        raise ValueError('Please supply a checkpoint')

    if FLAGS.num_of_classes is None:
        raise ValueError('Please supply num_of_classes.')

    if FLAGS.event_dir is None:
        FLAGS.event_dir = os.path.join(FLAGS.ckpt, 'event')

    # Start from an empty event directory so earlier runs do not mix in.
    if tf.gfile.Exists(FLAGS.event_dir):
        tf.gfile.DeleteRecursively(FLAGS.event_dir)
    tf.gfile.MakeDirs(FLAGS.event_dir)

    with tf.Graph().as_default() as g:
        nn = CNN(FLAGS.network, FLAGS.num_of_classes, 1, FLAGS.image_size,
                 FLAGS.image_crop_size, True, FLAGS.grayscale, True,
                 FLAGS.use_fp16)

        with tf.device("/cpu:0"):
            file_data = tf.read_file(FLAGS.image)
            image = tf.image.decode_jpeg(file_data, channels=3)
            image = nn.input(image, True, False)

        dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
        keep_prob = tf.constant(1.0, dtype=dtype)
        batch_size = tf.constant(1)
        logits = nn.inference(image, keep_prob, batch_size)

        # Calculate predictions.
        logits = tf.cast(logits, tf.float32)
        softmax = tf.nn.softmax(logits)

        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            CNN.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        info = [['Image', str(FLAGS.image)],
                ['Image size', str(FLAGS.image_size)],
                ['Image Crop size',
                 str(FLAGS.image_crop_size)],
                ['Grayscale', str(FLAGS.grayscale)],
                ['Use Float16', str(FLAGS.use_fp16)]]
        # collections=[] keeps this out of merge_all(); it is run once below.
        info_summary = tf.summary.text('Info',
                                       tf.convert_to_tensor(info),
                                       collections=[])

        # merge_all() returns None when the graph defines no summaries.
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.event_dir, g)

        try:
            with tf.Session() as sess:
                summary_writer.add_summary(sess.run(info_summary))

                ckpt = tf.train.get_checkpoint_state(FLAGS.ckpt)
                if not (ckpt and ckpt.model_checkpoint_path):
                    print('No checkpoint file found')
                    return
                saver.restore(sess, ckpt.model_checkpoint_path)
                # Assuming the path ends in "model.ckpt-<step>", the text
                # after the last '-' is the global step.
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                    '-')[-1]

                # First (and only) row of the batch-of-one output.
                result = sess.run(softmax)[0]

                print('- Result')
                for i, v in enumerate(result):
                    print('[%d]: %f' % (i, v))

                summary = tf.Summary()
                if summary_op is not None:
                    summary.ParseFromString(sess.run(summary_op))
                summary_writer.add_summary(summary, global_step)
        finally:
            # FileWriter buffers events; close it so they reach the disk.
            summary_writer.close()
예제 #8
0
def train(data_dir):
    """Train the CNN on the images under ``data_dir``.

    The class count comes from ``_num_of_folders(data_dir)`` and the sample
    count from ``_num_of_files(data_dir)``. Training runs inside a
    ``MonitoredTrainingSession`` that checkpoints to ``FLAGS.ckpt`` and stops
    at ``FLAGS.max_steps``. A ``NanTensorHook`` watches the loss, and
    accuracy, loss and throughput are printed every ``FLAGS.log_frequency``
    steps.

    Args:
        data_dir: directory holding the training images.

    Raises:
        ValueError: if ``data_dir`` is missing, yields no classes, or its
            sample count is unknown.
    """
    # PreProcess
    if not data_dir:
        raise ValueError('Please supply a data_dir')

    num_of_classes = _num_of_folders(data_dir)
    if num_of_classes <= 0:
        raise ValueError('Invalid num_of_classes')

    num_of_samples = _num_of_files(data_dir)
    if num_of_samples is None:
        raise ValueError('Please supply num_of_samples.')

    print('[ SUMMARY ]')
    print('Num of classes: {}'.format(num_of_classes))
    print('Num of samples: {}'.format(num_of_samples))

    # Training
    with tf.Graph().as_default():
        info = [['Number of classes', str(num_of_classes)],
                ['Number of samples', str(num_of_samples)],
                ['Image size', str(FLAGS.image_size)],
                ['Image Crop size',
                 str(FLAGS.image_crop_size)],
                ['Grayscale', str(FLAGS.grayscale)],
                ['Use Float16', str(FLAGS.use_fp16)]]
        # NOTE(review): created with collections=[] and the returned op is
        # discarded, so this text summary is never run or written during
        # training — confirm whether it should be recorded.
        tf.summary.text('NetworkInfo',
                        tf.convert_to_tensor(info),
                        collections=[])

        nn = CNN(FLAGS.network, num_of_classes, num_of_samples,
                 FLAGS.image_size, FLAGS.image_crop_size, FLAGS.log_input,
                 FLAGS.grayscale, FLAGS.log_feature, FLAGS.use_fp16)
        global_step = tf.contrib.framework.get_or_create_global_step()

        with tf.device('/gpu:0'):
            images, labels, filenames = nn.inputs(data_dir, FLAGS.batch_size,
                                                  FLAGS.batch_size * 500,
                                                  FLAGS.scaling,
                                                  FLAGS.destorted)

        dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
        keep_prob = tf.placeholder(dtype, name="keep_prob")
        # NOTE(review): batch_size is a float32 placeholder here, while the
        # evaluation/export paths build it as an int32 constant — confirm
        # nn.inference accepts both.
        batch_size = tf.placeholder(tf.float32, name="batch_size")
        logits = nn.inference(images, keep_prob, batch_size)
        loss, accuracy = nn.loss(logits, labels)

        train_op = nn.train(loss, global_step, FLAGS.batch_size,
                            FLAGS.learning_rate, FLAGS.num_epochs_per_decay,
                            FLAGS.learning_rate_decay_factor)

        class _LoggerHook(tf.train.SessionRunHook):
            """Print accuracy, loss and throughput every log_frequency steps."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Fetch loss and accuracy together with the training op.
                return tf.train.SessionRunArgs([loss, accuracy])

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value, accuracy_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, accuracy = %.2f, loss = %.2f (%1.f examples/sec; %.3f sec/batch)'
                    )
                    print(format_str %
                          (datetime.now(), self._step, accuracy_value,
                           loss_value, examples_per_sec, sec_per_batch))

        conf = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement,
                              allow_soft_placement=True,
                              intra_op_parallelism_threads=8)

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.ckpt,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                save_summaries_steps=FLAGS.save_steps,
                config=conf) as mon_sess:

            while not mon_sess.should_stop():
                mon_sess.run(train_op, {
                    keep_prob: FLAGS.keep_prob,
                    batch_size: FLAGS.batch_size
                })
예제 #9
0
파일: train.py 프로젝트: quanon/ppp
# Load the training and test sets and pass each through
# shaffle_images_and_labels (presumably a shuffle; defined elsewhere).
train_images, train_labels = fetch_images_and_labels(TRAIN_DIR)
train_images, train_labels = shaffle_images_and_labels(train_images,
                                                       train_labels)

test_images, test_labels = fetch_images_and_labels(TEST_DIR)
test_images, test_labels = shaffle_images_and_labels(test_images, test_labels)

# NOTE(review): CNN is constructed outside the `tf.Graph().as_default()` block
# below; if its constructor creates TensorFlow ops/variables they land in the
# global default graph rather than the training graph — confirm.
cnn = CNN(image_size=FLAGS.image_size, class_count=len(CLASSES))

with tf.Graph().as_default():
    # Inputs: rows of PIXEL_COUNT values, label rows with len(CLASSES) columns,
    # and the keep_prob scalar passed to the network.
    x = tf.placeholder(tf.float32, [None, PIXEL_COUNT])
    labels = tf.placeholder(tf.float32, [None, len(CLASSES)])
    keep_prob = tf.placeholder(tf.float32)

    # Network output, cross-entropy loss, training op and accuracy tensor,
    # all built by the CNN helper methods.
    y = cnn.inference(x, keep_prob)
    v = cnn.cross_entropy(y, labels)
    train_step = cnn.train_step(v, FLAGS.learning_rate)
    accuracy = cnn.accuracy(y, labels)

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)

        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)

        # FLAGS.step_count outer iterations; the inner count is
        # floor(len(train_images) / FLAGS.batch_size). Assuming each j indexes
        # one full batch, a trailing partial batch is never visited.
        for i in range(FLAGS.step_count):
            for j in range(int(len(train_images) / FLAGS.batch_size)):
예제 #10
0
파일: inference.py 프로젝트: quanon/ppp
    cv2.putText(image, CLASSES[prediction], (x, y), cv2.FONT_HERSHEY_DUPLEX,
                0.5, (255, 255, 102), 1, CV_AA)

    return image


# Read the image given on the command line and cut out the regions returned
# by detect() (presumably face rectangles) via get_face_images().
test_image = cv2.imread(sys.argv[1])
face_rects = detect(test_image)
face_images = get_face_images(test_image, face_rects)

x = tf.placeholder(tf.float32, [None, PIXEL_COUNT])
# NOTE(review): `labels` is not used in the visible part of this script.
labels = tf.placeholder(tf.float32, [None, len(CLASSES)])
keep_prob = tf.placeholder(tf.float32)

# Build the network with softmax output and restore the trained weights.
cnn = CNN(image_size=FLAGS.image_size, class_count=len(CLASSES))
y = cnn.inference(x, keep_prob, softmax=True)
# NOTE(review): this InteractiveSession is never closed in the visible code.
sess = tf.InteractiveSession()
saver = tf.train.Saver()
# Initialize all variables, then overwrite the saved ones from the checkpoint.
sess.run(tf.global_variables_initializer())
saver.restore(sess, os.path.join(LOG_DIR, 'model.ckpt'))

# Copy of the input image (presumably annotated further down the script).
out_image = test_image.copy()

for i in range(len(face_images)):
    face_image = face_images[i]
    # One face per run with keep_prob 1.0; flatten() turns the network output
    # into a 1-D score vector, whose argmax is the predicted class index.
    softmax = sess.run(y, feed_dict={
        x: [face_image],
        keep_prob: 1.0
    }).flatten()
    prediction = np.argmax(softmax)
예제 #11
0
def export_serving():
    """Export the latest checkpoint as a TensorFlow Serving SavedModel.

    Rebuilds the inference graph around a ``float32`` placeholder named
    ``image`` (shape ``(None, None, 3)``) with ``keep_prob`` fixed to 1.0 and a
    batch size of 1. It restores the moving-average weights from
    ``FLAGS.ckpt`` and re-saves a checkpoint plus ``graph.pbtxt`` under
    ``FLAGS.ckpt/export_serving``. Finally it writes a SavedModel to
    ``export_serving/<FLAGS.model_version>`` with a ``predict_image``
    signature and a default classification signature.
    """
    output_dir = os.path.join(FLAGS.ckpt, "export_serving")
    # Always start from an empty export directory.
    if tf.gfile.Exists(output_dir):
        tf.gfile.DeleteRecursively(output_dir)
    tf.gfile.MakeDirs(output_dir)

    ### EDIT TRAINING GRAPH ###
    with tf.Graph().as_default():
        nn = CNN(FLAGS.network, FLAGS.num_of_classes, 1, FLAGS.image_size,
                 FLAGS.image_crop_size, FLAGS.log_input, FLAGS.grayscale,
                 FLAGS.log_feature, FLAGS.use_fp16)

        with tf.device("/cpu:0"):
            # Keep a handle on the raw input placeholder: it, not the
            # nn.input() output, is what the serving signatures must expose,
            # so that preprocessing stays inside the exported graph.
            image_input = tf.placeholder(tf.float32,
                                         shape=(None, None, 3),
                                         name="image")
            image = nn.input(image_input, True, False)

        dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
        keep_prob = tf.constant(1.0, dtype=dtype)
        batch_size = tf.constant(1)
        logits = nn.inference(image, keep_prob, batch_size)

        # Calculate predictions.
        logits = tf.cast(logits, tf.float32)
        softmax = tf.nn.softmax(logits)

        # Restore the moving average version of the learned variables.
        variable_averages = tf.train.ExponentialMovingAverage(
            CNN.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(FLAGS.ckpt)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Assuming the path ends in "model.ckpt-<step>", the text after
            # the last '-' is the global step.
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                '-')[-1]

            tf.train.Saver().save(
                sess,
                os.path.join(output_dir, 'model.ckpt'),
                global_step=tf.convert_to_tensor(global_step))
            tf.train.write_graph(sess.graph.as_graph_def(),
                                 output_dir,
                                 'graph.pbtxt',
                                 as_text=True)

            ### EXPORT MODEL ###
            export_path = os.path.join(
                tf.compat.as_bytes(output_dir),
                tf.compat.as_bytes(str(FLAGS.model_version)))
            # The original format string had no placeholder, so the path was
            # never printed.
            print('Exporting model to {}'.format(tf.compat.as_str(export_path)))

            builder = tf.saved_model.builder.SavedModelBuilder(export_path)

            image_info = tf.saved_model.utils.build_tensor_info(image_input)
            softmax_info = tf.saved_model.utils.build_tensor_info(softmax)

            # NOTE(review): both CLASSES and SCORES map to the softmax tensor,
            # and the input is a float image rather than serialized
            # tf.Example strings, which TF Serving's Classify API expects —
            # confirm clients only call this signature through Predict.
            classification_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={
                        tf.saved_model.signature_constants.CLASSIFY_INPUTS:
                        image_info
                    },
                    outputs={
                        tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES:
                        softmax_info,
                        tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES:
                        softmax_info
                    },
                    method_name=tf.saved_model.signature_constants.
                    CLASSIFY_METHOD_NAME))

            prediction_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={'image': image_info},
                    outputs={'score': softmax_info},
                    method_name=tf.saved_model.signature_constants.
                    PREDICT_METHOD_NAME))

            legacy_init_op = tf.group(tf.tables_initializer(),
                                      name='legacy_init_op')
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    'predict_image':
                    prediction_signature,
                    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    classification_signature,
                },
                legacy_init_op=legacy_init_op)

            builder.save()

            print('Done exporting!')