Example #1
def train(re_train=True):
	"""Train CIFAR-10 for a number of steps."""
	images, labels = get_images_and_labels()

	# print "found " + str(len(images)) + " images"
	train_size = 5
	train_images = images[:train_size, :, :, :]
	train_labels = labels[:train_size]

	val_images = images[train_size:, :, :, :]
	val_labels = labels[train_size:]

	with tf.Graph().as_default():
		global_step = tf.Variable(0, trainable=False)

		images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)

		print(images.get_shape(), val_images.get_shape())
		logits = my_cifar.inference(images_placeholder)

		# Calculate loss.
		loss = my_cifar.loss(logits, labels_placeholder)

		# Build a Graph that trains the model with one batch of examples and
		# updates the model parameters.
		train_op = my_cifar.training(loss, global_step)

		# Calculate accuracy #
		acc, n_correct = my_cifar.evaluation(logits, labels_placeholder)

		# Create a saver.
		saver = tf.train.Saver()

		tf.scalar_summary('Acc', acc)
		# tf.scalar_summary('Val Acc', acc_val)
		tf.scalar_summary('Loss', loss)
		tf.image_summary('Images', tf.reshape(images, shape=[-1, 40, 40, 3]), max_images=10)
		tf.image_summary('Val Images', tf.reshape(val_images, shape=[-1, 40, 40, 3]), max_images=10)

		# Build the summary operation based on the TF collection of Summaries.
		summary_op = tf.merge_all_summaries()

		# Build an initialization operation to run below.
		init = tf.initialize_all_variables()

		# Start running operations on the Graph.
		# NUM_CORES = 2  # Choose how many cores to use.
		sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement, ))
		# inter_op_parallelism_threads=NUM_CORES,
		# intra_op_parallelism_threads=NUM_CORES))
		sess.run(init)

		# Write all terminal output results here
		val_f = open("tmp/val.txt", "ab")

		# Start the queue runners.
		coord = tf.train.Coordinator()
		threads = tf.train.start_queue_runners(sess=sess, coord=coord)

		summary_writer = tf.train.SummaryWriter(FLAGS.train_dirr,
		                                        graph_def=sess.graph_def)


		# ckpt = tf.train.get_checkpoint_state(checkpoint_dir=FLAGS.checkpoint_dir)
		# print ckpt.model_checkpoint_path
		# if ckpt and ckpt.model_checkpoint_path:
		# 	saver.restore(sess, ckpt.model_checkpoint_path)
		# print('Restored!')

		for i in range(10):
			images_val_r, labels_val_r = sess.run([val_images, val_labels])
			val_feed = {images_placeholder: images_val_r,
			            labels_placeholder: labels_val_r}

			# Evaluate accuracy on each validation batch as it is dequeued.
			print('Calculating Acc: ')
			acc_r = sess.run(acc, feed_dict=val_feed)
			print(acc_r)

		# Ask the queue-runner threads to stop before joining them.
		coord.request_stop()
		coord.join(threads)
		val_f.close()
		sess.close()
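All of these snippets are excerpts from larger modules: they assume module-level imports (tensorflow as tf, numpy as np, os, time, datetime) plus helpers such as FLAGS, my_cifar, NEW_LINE, and the step-interval constants. One helper every example calls but none defines is placeholder_inputs; a minimal sketch of what it plausibly looks like, following the TensorFlow tutorial convention (the 32x32x3 shape and dtypes are assumptions; Example #1 reshapes images to 40x40, so the size may differ per repo):

import tensorflow as tf

def placeholder_inputs(batch_size):
    """Hypothetical helper: feed placeholders sized to one batch."""
    # Assumed shapes: 32x32 RGB images and integer class labels.
    images_placeholder = tf.placeholder(tf.float32,
                                        shape=(batch_size, 32, 32, 3))
    labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size,))
    return images_placeholder, labels_placeholder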
Example #2
def train(re_train=True):
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        images_placeholder, labels_placeholder = placeholder_inputs(
            FLAGS.batch_size)

        # Get images and labels for CIFAR-10.
        # images, labels = my_input.inputs()
        images, labels = imageflow.distorted_inputs(
            filename='../my_data_raw/train.tfrecords',
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
            num_threads=5,
            imshape=[32, 32, 3],
            imsize=32)
        val_images, val_labels = imageflow.inputs(
            filename='../my_data_raw/validation.tfrecords',
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
            num_threads=5,
            imshape=[32, 32, 3])

        print(images.get_shape(), val_images.get_shape())
        # Build a Graph that computes the logits predictions from the inference model.
        logits = my_cifar.inference(images_placeholder)

        # Calculate loss.
        loss = my_cifar.loss(logits, labels_placeholder)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = my_cifar.training(loss, global_step)

        # Calculate accuracy #
        acc, n_correct = my_cifar.evaluation(logits, labels_placeholder)

        # Create a saver.
        saver = tf.train.Saver()

        tf.scalar_summary('Acc', acc)
        # tf.scalar_summary('Val Acc', acc_val)
        tf.scalar_summary('Loss', loss)
        tf.image_summary('Images',
                         tf.reshape(images, shape=[-1, 32, 32, 3]),
                         max_images=10)
        tf.image_summary('Val Images',
                         tf.reshape(val_images, shape=[-1, 32, 32, 3]),
                         max_images=10)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        # NUM_CORES = 2  # Choose how many cores to use.
        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement, ))
        # inter_op_parallelism_threads=NUM_CORES,
        # intra_op_parallelism_threads=NUM_CORES))
        sess.run(init)

        # Write all terminal output results here
        val_f = open("tmp/val.txt", "ab")

        # Start the queue runners.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        summary_writer = tf.train.SummaryWriter(FLAGS.train_dirr,
                                                graph_def=sess.graph_def)

        if re_train:

            # Export graph to import it later in c++
            # tf.train.write_graph(sess.graph_def, FLAGS.model_dir, 'train.pbtxt') # TODO: uncomment to get graph and use in c++

            continue_from_pre = False

            if continue_from_pre:
                ckpt = tf.train.get_checkpoint_state(
                    checkpoint_dir=FLAGS.checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    print(ckpt.model_checkpoint_path)
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    print('Session Restored!')

            try:
                while not coord.should_stop():

                    for step in xrange(FLAGS.max_steps):

                        images_r, labels_r = sess.run([images, labels])
                        images_val_r, labels_val_r = sess.run(
                            [val_images, val_labels])

                        train_feed = {
                            images_placeholder: images_r,
                            labels_placeholder: labels_r
                        }

                        val_feed = {
                            images_placeholder: images_val_r,
                            labels_placeholder: labels_val_r
                        }

                        start_time = time.time()

                        _, loss_value = sess.run([train_op, loss],
                                                 feed_dict=train_feed)
                        duration = time.time() - start_time

                        assert not np.isnan(
                            loss_value), 'Model diverged with loss = NaN'

                        if step % display_step == 0:
                            num_examples_per_step = FLAGS.batch_size
                            examples_per_sec = num_examples_per_step / duration
                            sec_per_batch = float(duration)

                            format_str = (
                                '%s: step %d, loss = %.6f (%.1f examples/sec; %.3f '
                                'sec/batch)')
                            print_str_loss = format_str % (
                                datetime.now(), step, loss_value,
                                examples_per_sec, sec_per_batch)
                            print(print_str_loss)
                            val_f.write(print_str_loss + NEW_LINE)
                            summary_str = sess.run([summary_op],
                                                   feed_dict=train_feed)
                            summary_writer.add_summary(summary_str[0], step)

                        if step % val_step == 0:
                            acc_value, num_correct = sess.run(
                                [acc, n_correct], feed_dict=train_feed)

                            format_str = '%s: step %d, train acc = %.2f, n_correct = %d'
                            print_str_train = format_str % (
                                datetime.now(), step, acc_value, num_correct)
                            val_f.write(print_str_train + NEW_LINE)
                            print(print_str_train)

                        # Save the model checkpoint periodically.
                        if step % save_step == 0 or (step +
                                                     1) == FLAGS.max_steps:
                            val_acc_r, val_n_correct_r = sess.run(
                                [acc, n_correct], feed_dict=val_feed)

                            frmt_str = ' step %d, Val Acc = %.2f, num correct = %d'
                            print_str_val = frmt_str % (step, val_acc_r,
                                                        val_n_correct_r)
                            val_f.write(print_str_val + NEW_LINE)
                            print(print_str_val)

                            checkpoint_path = os.path.join(
                                FLAGS.checkpoint_dir, 'model.ckpt')
                            saver.save(sess, checkpoint_path, global_step=step)

            except tf.errors.OutOfRangeError:
                print('Done training -- epoch limit reached')

            finally:
                # When done, ask the threads to stop.
                val_f.write(
                    NEW_LINE + NEW_LINE +
                    '############################ FINISHED ############################'
                    + NEW_LINE)
                val_f.close()
                coord.request_stop()


        else:

            ckpt = tf.train.get_checkpoint_state(
                checkpoint_dir=FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print(ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Restored!')

            for i in range(100):
                images_val_r, labels_val_r = sess.run([val_images, val_labels])
                val_feed = {
                    images_placeholder: images_val_r,
                    labels_placeholder: labels_val_r
                }

                print('Calculating Acc: ')

                acc_r = sess.run(acc, feed_dict=val_feed)
                print(acc_r)

        # When done, ask the threads to stop and wait for them to finish.
        coord.request_stop()
        coord.join(threads)
        sess.close()
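Every example here uses the pre-1.0 TensorFlow API. If you run this code under TensorFlow 1.x, the summary and initialization calls were renamed; a quick mapping (arguments are otherwise unchanged, except max_images became max_outputs):

# Pre-1.0 name (used above)                  ->  TensorFlow 1.x name
# tf.scalar_summary(tag, value)              ->  tf.summary.scalar(tag, value)
# tf.image_summary(tag, imgs, max_images=n)  ->  tf.summary.image(tag, imgs, max_outputs=n)
# tf.merge_all_summaries()                   ->  tf.summary.merge_all()
# tf.initialize_all_variables()              ->  tf.global_variables_initializer()
# tf.train.SummaryWriter(logdir, ...)        ->  tf.summary.FileWriter(logdir, ...)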
Example #3
def train(re_train=True):
    """Train CIFAR-10 for a number of steps."""
    images, labels = get_images_and_labels()

    # print "found " + str(len(images)) + " images"
    train_size = 5
    train_images = images[:train_size, :, :, :]
    train_labels = labels[:train_size]

    val_images = images[train_size:, :, :, :]
    val_labels = labels[train_size:]

    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        images_placeholder, labels_placeholder = placeholder_inputs(
            FLAGS.batch_size)

        print(images.get_shape(), val_images.get_shape())
        logits = my_cifar.inference(images_placeholder)

        # Calculate loss.
        loss = my_cifar.loss(logits, labels_placeholder)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = my_cifar.training(loss, global_step)

        # Calculate accuracy #
        acc, n_correct = my_cifar.evaluation(logits, labels_placeholder)

        # Create a saver.
        saver = tf.train.Saver()

        tf.scalar_summary('Acc', acc)
        # tf.scalar_summary('Val Acc', acc_val)
        tf.scalar_summary('Loss', loss)
        tf.image_summary('Images',
                         tf.reshape(images, shape=[-1, 40, 40, 3]),
                         max_images=10)
        tf.image_summary('Val Images',
                         tf.reshape(val_images, shape=[-1, 40, 40, 3]),
                         max_images=10)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        # NUM_CORES = 2  # Choose how many cores to use.
        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement, ))
        # inter_op_parallelism_threads=NUM_CORES,
        # intra_op_parallelism_threads=NUM_CORES))
        sess.run(init)

        # Write all terminal output results here
        val_f = open("tmp/val.txt", "ab")

        # Start the queue runners.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        summary_writer = tf.train.SummaryWriter(FLAGS.train_dirr,
                                                graph_def=sess.graph_def)

        # ckpt = tf.train.get_checkpoint_state(checkpoint_dir=FLAGS.checkpoint_dir)
        # print ckpt.model_checkpoint_path
        # if ckpt and ckpt.model_checkpoint_path:
        # 	saver.restore(sess, ckpt.model_checkpoint_path)
        # print('Restored!')

        for i in range(10):
            images_val_r, labels_val_r = sess.run([val_images, val_labels])
            val_feed = {
                images_placeholder: images_val_r,
                labels_placeholder: labels_val_r
            }

            # Evaluate accuracy on each validation batch as it is dequeued.
            print('Calculating Acc: ')
            acc_r = sess.run(acc, feed_dict=val_feed)
            print(acc_r)

        # Ask the queue-runner threads to stop before joining them.
        coord.request_stop()
        coord.join(threads)
        val_f.close()
        sess.close()
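Examples #1 and #3 drive the input pipeline with queue runners, so the shutdown order matters: request a stop before joining the threads, or the join blocks forever. For reference, the canonical coordinator pattern from the TensorFlow queue-runner documentation:

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    while not coord.should_stop():
        sess.run(train_op)  # runs until the input queues raise OutOfRangeError
except tf.errors.OutOfRangeError:
    print('Done -- epoch limit reached')
finally:
    coord.request_stop()    # ask the input threads to exit
coord.join(threads)         # wait for them before closing the session
sess.close()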
Example #4
def train(re_train=True):
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)

    # Get images and labels for CIFAR-10.
    # images, labels = my_input.inputs()
    images, labels = input_data.distorted_inputs()
    val_images, val_labels = input_data.inputs(False)

    # Build a Graph that computes the logits predictions from the inference model.
    logits = my_cifar.inference(images_placeholder)

    # Calculate loss.
    loss = my_cifar.loss(logits, labels_placeholder)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = my_cifar.training(loss, global_step)

    # Calculate accuracy #
    acc, n_correct = my_cifar.evaluation(logits, labels_placeholder)

    # Create a saver.
    saver = tf.train.Saver()

    tf.scalar_summary('Acc', acc)
    # tf.scalar_summary('Val Acc', acc_val)
    tf.scalar_summary('Loss', loss)
    tf.image_summary('Images', tf.reshape(images, shape=[-1, 32, 32, 3]), max_images=10)
    tf.image_summary('Val Images', tf.reshape(val_images, shape=[-1, 32, 32, 3]), max_images=10)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    # NUM_CORES = 2  # Choose how many cores to use.
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement, ))
    # inter_op_parallelism_threads=NUM_CORES,
    # intra_op_parallelism_threads=NUM_CORES))
    sess.run(init)

    # Write all terminal output results here
    val_f = open("tmp/val.txt", "ab")

    # Start the queue runners.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    summary_writer = tf.train.SummaryWriter(FLAGS.train_dirr,
                                            graph_def=sess.graph_def)

    if re_train:

      # Export graph to import it later in c++
      # tf.train.write_graph(sess.graph_def, FLAGS.model_dir, 'train.pbtxt') # TODO: uncomment to get graph and use in c++

      continue_from_pre = False

      if continue_from_pre:
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir=FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
          print(ckpt.model_checkpoint_path)
          saver.restore(sess, ckpt.model_checkpoint_path)
          print('Session Restored!')

      try:
        while not coord.should_stop():

          for step in xrange(FLAGS.max_steps):

            images_r, labels_r = sess.run([images, labels])
            images_val_r, labels_val_r = sess.run([val_images, val_labels])

            train_feed = {images_placeholder: images_r,
                          labels_placeholder: labels_r}

            val_feed = {images_placeholder: images_val_r,
                        labels_placeholder: labels_val_r}

            start_time = time.time()

            _, loss_value = sess.run([train_op, loss], feed_dict=train_feed)
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % display_step == 0:
              num_examples_per_step = FLAGS.batch_size
              examples_per_sec = num_examples_per_step / duration
              sec_per_batch = float(duration)

              format_str = ('%s: step %d, loss = %.6f (%.1f examples/sec; %.3f '
                            'sec/batch)')
              print_str_loss = format_str % (datetime.now(), step, loss_value,
                                             examples_per_sec, sec_per_batch)
              print(print_str_loss)
              val_f.write(print_str_loss + NEW_LINE)
              summary_str = sess.run([summary_op], feed_dict=train_feed)
              summary_writer.add_summary(summary_str[0], step)

            if step % val_step == 0:
              acc_value, num_correct = sess.run([acc, n_correct], feed_dict=train_feed)

              format_str = '%s: step %d, train acc = %.2f, n_correct = %d'
              print_str_train = format_str % (datetime.now(), step, acc_value, num_correct)
              val_f.write(print_str_train + NEW_LINE)
              print(print_str_train)

            # Save the model checkpoint periodically.
            if step % save_step == 0 or (step + 1) == FLAGS.max_steps:
              val_acc_r, val_n_correct_r = sess.run([acc, n_correct], feed_dict=val_feed)

              frmt_str = ' step %d, Val Acc = %.2f, num correct = %d'
              print_str_val = frmt_str % (step, val_acc_r, val_n_correct_r)
              val_f.write(print_str_val + NEW_LINE)
              print(print_str_val)

              checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'model.ckpt')
              saver.save(sess, checkpoint_path, global_step=step)


      except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')

      finally:
        # When done, ask the threads to stop.
        val_f.write(NEW_LINE +
                    NEW_LINE +
                    '############################ FINISHED ############################' +
                    NEW_LINE)
        val_f.close()
        coord.request_stop()


    else:

      ckpt = tf.train.get_checkpoint_state(checkpoint_dir=FLAGS.checkpoint_dir)
      if ckpt and ckpt.model_checkpoint_path:
        print(ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('Restored!')

      for i in range(100):
        images_val_r, labels_val_r = sess.run([val_images, val_labels])
        val_feed = {images_placeholder: images_val_r,
                    labels_placeholder: labels_val_r}

        print('Calculating Acc: ')

        acc_r = sess.run(acc, feed_dict=val_feed)
        print(acc_r)

    # When done, ask the threads to stop and wait for them to finish.
    coord.request_stop()
    coord.join(threads)
    sess.close()
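Note the round trip in Examples #2 and #4: each step fetches a batch with sess.run([images, labels]) and then feeds the arrays back in through placeholders, copying every batch out of and back into the runtime. A sketch of the more direct wiring, building the model on the pipeline tensors themselves (assuming the same my_cifar API as above):

# Build the graph directly on the queue outputs; no placeholders needed.
logits = my_cifar.inference(images)
loss = my_cifar.loss(logits, labels)
train_op = my_cifar.training(loss, global_step)

for step in xrange(FLAGS.max_steps):
    # Each run dequeues a fresh batch inside the graph.
    _, loss_value = sess.run([train_op, loss])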
Example #5
def train(TRAIN=True):
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():

        global_step = tf.Variable(0, trainable=False)

        images_placeholder, labels_placeholder = placeholder_inputs(
            FLAGS.batch_size)
        # Get images and labels for CIFAR-10.
        images, labels = my_input.distorted_inputs()
        val_images, val_labels = my_input.inputs(False)
        print('2- images shape is ', images.get_shape())
        print('3- labels shape is ', labels.get_shape())

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = my_cifar.inference(images_placeholder)
        print('4- logits shape is ', logits.get_shape())

        # Calculate loss.
        loss = my_cifar.loss(logits, labels_placeholder)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = my_cifar.training(loss, global_step)

        # Calculate accuracy ##
        acc, n_correct = my_cifar.evaluation(logits, labels_placeholder)

        # Create a saver.
        saver = tf.train.Saver()

        tf.scalar_summary('Acc', acc)
        tf.scalar_summary('Loss', loss)
        tf.image_summary('Images',
                         tf.reshape(images, shape=[-1, 32, 32, 3]),
                         max_images=10)
        tf.image_summary('Val Images',
                         tf.reshape(val_images, shape=[-1, 32, 32, 3]),
                         max_images=10)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Start the queue runners.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        summary_writer = tf.train.SummaryWriter(FLAGS.train_dirr,
                                                graph_def=sess.graph_def)

        if TRAIN:
            try:
                while not coord.should_stop():

                    for step in xrange(FLAGS.max_steps):

                        images_r, labels_r = sess.run([images, labels])
                        images_val_r, labels_val_r = sess.run(
                            [val_images, val_labels])

                        train_feed = {
                            images_placeholder: images_r,
                            labels_placeholder: labels_r
                        }

                        val_feed = {
                            images_placeholder: images_val_r,
                            labels_placeholder: labels_val_r
                        }

                        start_time = time.time()

                        _, loss_value = sess.run([train_op, loss],
                                                 feed_dict=train_feed)
                        duration = time.time() - start_time

                        assert not np.isnan(
                            loss_value), 'Model diverged with loss = NaN'

                        if step % display_step == 0:
                            num_examples_per_step = FLAGS.batch_size
                            examples_per_sec = num_examples_per_step / duration
                            sec_per_batch = float(duration)

                            format_str = (
                                '%s: step %d, loss = %.6f (%.1f examples/sec; %.3f '
                                'sec/batch)')
                            print_str_loss = format_str % (
                                datetime.now(), step, loss_value,
                                examples_per_sec, sec_per_batch)
                            print(print_str_loss)
                            summary_str = sess.run([summary_op],
                                                   feed_dict=train_feed)
                            summary_writer.add_summary(summary_str[0], step)

                        if step % (display_step * 5) == 0:
                            acc_value, num_correct = sess.run(
                                [acc, n_correct], feed_dict=train_feed)

                            format_str = '%s: step %d, train acc = %.2f, n_correct = %d'
                            print_str_train = format_str % (
                                datetime.now(), step, acc_value, num_correct)
                            print(print_str_train)

                        # Save the model checkpoint periodically.
                        if (step + 1) % (display_step * 10) == 0 or (
                                step + 1) == FLAGS.max_steps:
                            val_acc_r, val_n_correct_r = sess.run(
                                [acc, n_correct], feed_dict=val_feed)

                            frmt_str = 'Step %d, Val Acc = %.2f, num correct = %d'
                            print_str_val = frmt_str % (step, val_acc_r,
                                                        val_n_correct_r)
                            print(print_str_val)

                            # checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'model.ckpt')
                            # saver.save(sess, checkpoint_path, global_step=step, latest_filename=checkpoint_state_name)

                            checkpoint_prefix = os.path.join(
                                FLAGS.checkpoint_dir, "saved_checkpoint")
                            saver.save(sess,
                                       checkpoint_prefix,
                                       global_step=0,
                                       latest_filename=checkpoint_state_name)

            except tf.errors.OutOfRangeError:
                print('Done training -- epoch limit reached')

            finally:
                # When done, ask the threads to stop.
                coord.request_stop()
                # TODO #3.1: Freeze the graph once training has finished.
                freeze_my_graph(sess)

            # Wait for threads to finish.
            coord.join(threads)
            sess.close()

        # If the TRAIN argument is False, load from the checkpoint file and freeze the graph.
        else:
            # TODO #3.2: You can also freeze the graph from the latest
            # checkpoint if you don't want to wait a long time.
            freeze_my_graph(sess)
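freeze_my_graph is defined elsewhere in this example's repository. A minimal sketch of what such a helper typically does, using graph_util.convert_variables_to_constants; the output node name 'softmax_linear' is an assumption and must match the real graph:

from tensorflow.python.framework import graph_util

def freeze_my_graph(sess):
    # Fold the current variable values into the GraphDef as constants.
    # 'softmax_linear' is a hypothetical output node name.
    frozen = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['softmax_linear'])
    # Serialize the frozen graph next to the checkpoints.
    tf.train.write_graph(frozen, FLAGS.checkpoint_dir, 'frozen.pb',
                         as_text=False)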