Example #1
def evaluate(output, train_dir):
  with tf.Graph().as_default() as g:
    filename_queue = tf.train.string_input_producer([output])
    read_input = input.read_cifar10(filename_queue)
    # Alternative preprocessing: cast to float and center-crop to 24x24.
    # reshaped_image = tf.cast(read_input.uint8image, tf.float32)
    # resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, 24, 24)
    resized_image = tf.image.resize_images(
        read_input.uint8image, [FLAGS.input_size, FLAGS.input_size])
    float_image = tf.image.per_image_standardization(resized_image)

    min_fraction_of_examples_in_queue = 0.4
    num_examples_per_epoch = FLAGS.num_examples
    min_queue_examples = int(num_examples_per_epoch *
                             min_fraction_of_examples_in_queue)
    batch_size = 128

    images, labels = input._generate_image_and_label_batch(
        float_image, read_input.label, min_queue_examples, batch_size,
        shuffle=False)
    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = core.inference(images)
    # Calculate predictions.
    top_k_op = tf.nn.top_k(tf.nn.softmax(logits), k=FLAGS.label_size)

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(
        core.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    return eval_once(saver, top_k_op, train_dir)
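eval_once is assumed to be defined elsewhere in this repository. A minimal sketch of what it might do, modeled on the TensorFlow CIFAR-10 tutorial's eval_once (the actual function may differ), is:

def eval_once(saver, top_k_op, train_dir):
  # Hypothetical sketch: restore the latest checkpoint and run one
  # evaluation pass over the queued input pipeline.
  with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(train_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
      print('No checkpoint file found')
      return None
    saver.restore(sess, ckpt.model_checkpoint_path)

    # Start the queue runners that feed string_input_producer above.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
      return sess.run(top_k_op)
    finally:
      coord.request_stop()
      coord.join(threads)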
Example #2
def eval(tfrecord_folder, dataset_split, is_training):
    with tf.Graph().as_default() as g:
        with tf.device('/cpu:0'):
            input_dict = input_pipeline.inputs(
                tfrecord_folder, dataset_split, is_training, is_vis=False, 
                batch_size=FLAGS.batch_size, num_epochs=1)
            images = input_dict['image']
            labels = input_dict['label']

        labels = tf.squeeze(labels, axis=[-1])
        logits = core.inference(FLAGS.model_variant, images, FLAGS.is_training)
        predictions = tf.argmax(
            logits, axis=-1, name='prediction', output_type=tf.int64)

        weights = tf.to_float(tf.not_equal(labels, core.IGNORE_LABEL))
        labels = tf.where(tf.equal(labels, core.IGNORE_LABEL),
                          tf.zeros_like(labels),
                          labels)
        mean_iou, update_op = tf.metrics.mean_iou(
            labels=labels, predictions=predictions,
            num_classes=core.NUMBER_CLASSES, weights=weights, name='mean_iou')

        summary_op = tf.summary.scalar('mean_iou', mean_iou)
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)

        num_batches = int(
            math.ceil(input_pipeline.NUMBER_VAL_DATA / float(FLAGS.batch_size)))

        # get global_step used in summary_writer.
        ckpt = tf.train.get_checkpoint_state(
            checkpoint_dir=FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print('Get global_step from checkpoint name.')
        else:
            # No checkpoint found; fall back to step 0 because add_summary
            # expects an int, not the global_step tensor.
            global_step = 0
            print('No checkpoint found; using global_step 0.')

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.train.MonitoredSession(
                session_creator=tf.train.ChiefSessionCreator(
                    config=config,
                    checkpoint_dir=FLAGS.checkpoint_dir
                )) as mon_sess:
            for _ in range(num_batches):
                mon_sess.run(update_op)

            summary = mon_sess.run(summary_op)
            summary_writer.add_summary(summary, global_step=global_step)
            summary_writer.flush()
            print('*' * 60)
            print('mean_iou:', mon_sess.run(mean_iou))
            print('*' * 60)
            summary_writer.close()
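tf.metrics.mean_iou is a streaming metric: each mon_sess.run(update_op) accumulates a confusion matrix held in local variables, and evaluating mean_iou reads the per-class IoU averaged over everything accumulated so far. A self-contained toy illustration (hypothetical labels and predictions):

import tensorflow as tf

labels = tf.constant([0, 1, 1, 0])
predictions = tf.constant([0, 1, 0, 0])
mean_iou, update_op = tf.metrics.mean_iou(
    labels=labels, predictions=predictions, num_classes=2)

with tf.Session() as sess:
    # The metric's running totals live in local variables.
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)        # accumulate one batch
    print(sess.run(mean_iou))  # read the running mean IoU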
Example #3
def train():
    start_time = time.time()
    checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')

    with tf.Graph().as_default():
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32, [None, IMAGE_SIZE * IMAGE_SIZE],
                               name='x-input')
            y_ = tf.placeholder(tf.float32, [None, 4], name='y-input')
        with tf.name_scope('input_reshape'):
            image_shaped_input = tf.reshape(x, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
            tf.summary.image('input', image_shaped_input, 10)
        with tf.name_scope('dropout_keep_prob'):
            keep_prob = tf.placeholder(tf.float32)
        y = core.inference(image_shaped_input, keep_prob=keep_prob)
        loss = core.loss(y, y_)
        train_op = core.train(loss, FLAGS.learning_rate)
        accuracy = core.evaluation(y, y_)
        summary = tf.summary.merge_all()
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
        sess = tf.Session()
        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
        sess.run(init)
        for i in range(20):  # 20 epochs over the training data
            images, labels = input_data.get_train_data()

            image_num, _ = images.shape
            batch_num = int(image_num / 100)  # mini-batches of 100 examples
            # Visit the mini-batches in a random order each epoch.
            random_index = random.sample(range(batch_num), batch_num)
            for j in range(batch_num):
                step = i * batch_num + j
                index = random_index[j]
                xs = images[index * 100:(index + 1) * 100]
                ys = labels[index * 100:(index + 1) * 100]
                if step % 50 == 0:
                    feed_dict = {x: xs, y_: ys, keep_prob: 1.0}
                    summary_str, acc = sess.run([summary, accuracy],
                                                feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    print("step %d, train accuracy %g" % (step, acc))
                else:
                    feed_dict = {
                        x: xs,
                        y_: ys,
                        keep_prob: FLAGS.dropout_keep_prob
                    }
                    summary_str, loss_value, _ = sess.run(
                        [summary, loss, train_op], feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                if (step + 1) % 1000 == 0:
                    saver.save(sess, checkpoint_file, global_step=step)
        saver.save(sess, checkpoint_file)
        duration = time.time() - start_time
        print('%d seconds' % int(duration))
Example #4
def train_data(data_dir, train_dir):
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = core.distorted_inputs(data_dir=data_dir)
        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = core.inference(images)
        # Calculate loss.
        loss = core.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = core.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1

            def before_run(self, run_context):
                self._step += 1
                self._start_time = time.time()
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                duration = time.time() - self._start_time
                loss_value = run_values.results
                if self._step % 10 == 0:
                    num_examples_per_step = FLAGS.batch_size
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
            # The loop exits only once the session should stop (e.g. the
            # StopAtStepHook fired), so signal successful completion.
            return True
Example #5
def eval(model_variant, tfrecord_dir, dataset_split):
    with tf.Graph().as_default() as g:
        with tf.device('/cpu:0'):
            images, labels = input_pipeline.inputs(tfrecord_dir,
                                                   dataset_split,
                                                   FLAGS.is_training,
                                                   FLAGS.batch_size,
                                                   num_epochs=1)
        predictions = core.inference(model_variant, images, FLAGS.is_training)

        mean_absolute_error, update_op = tf.metrics.mean_absolute_error(
            labels=labels, predictions=predictions, name='mean_absolute_error')

        summary_op = tf.summary.scalar('eval/mean_absolute_error',
                                       mean_absolute_error)
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)

        num_batches = int(
            math.ceil(input_pipeline.NUMBER_VAL_DATA / float(FLAGS.batch_size)))

        # get global_step used in summary_writer.
        ckpt = tf.train.get_checkpoint_state(
            checkpoint_dir=FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print('Get global_step from checkpoint name.')
        else:
            # No checkpoint found; fall back to step 0 because add_summary
            # expects an int, not the global_step tensor.
            global_step = 0
            print('No checkpoint found; using global_step 0.')

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True

        with tf.train.MonitoredSession(
                session_creator=tf.train.ChiefSessionCreator(
                    config=config,
                    checkpoint_dir=FLAGS.checkpoint_dir)) as mon_sess:
            for _ in range(num_batches):
                mon_sess.run(update_op)

            summary = mon_sess.run(summary_op)
            summary_writer.add_summary(summary, global_step=global_step)
            summary_writer.flush()
            print('*' * 50)
            print('Step {:06} mean_absolute_error:'.format(global_step),
                  mon_sess.run(mean_absolute_error))
            print('*' * 50)
            summary_writer.close()
Example #6
def predict():
    with tf.Graph().as_default() as g:
        image = tf.gfile.FastGFile(FLAGS.image_file, 'rb').read()
        image = tf.image.decode_jpeg(image, channels=3)
        original_image = image
        image = tf.image.resize_images(image,
                                       input_pipeline.INPUT_SIZE,
                                       method=tf.image.ResizeMethod.BILINEAR,
                                       align_corners=True)

        logits = core.inference(FLAGS.model_variant,
                                tf.expand_dims(image, axis=0),
                                FLAGS.is_training)
        prediction = tf.argmax(logits, axis=-1)
        prediction = tf.image.resize_nearest_neighbor(
            tf.expand_dims(prediction, axis=-1),
            tf.shape(original_image)[:2],
            align_corners=True)

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.train.MonitoredSession(
                session_creator=tf.train.ChiefSessionCreator(
                    config=config,
                    checkpoint_dir=FLAGS.checkpoint_dir)) as mon_sess:
            colormap = create_pascal_label_colormap()

            original_image, prediction = mon_sess.run(
                [original_image, prediction])
            prediction = np.squeeze(prediction)
            prediction = colormap[prediction]

            plt.figure(1)
            plt.subplot(121)
            plt.imshow(original_image)

            plt.subplot(122)
            plt.imshow(prediction)

            plt.show()
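create_pascal_label_colormap is assumed to be provided elsewhere in the repository. A common implementation, the bit-shuffling scheme used by the PASCAL VOC dev kit and DeepLab, is:

def create_pascal_label_colormap():
    # Build the 256-entry PASCAL VOC label colormap: the bits of each
    # label index are distributed across the RGB channels.
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)
    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3
    return colormap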
Example #7
def segment_image(image_path, image_index):
    start_time = time.time()
    with tf.Session() as sess:
        x = tf.placeholder("float", [None, IMAGE_SIZE * IMAGE_SIZE])
        x_image = tf.reshape(x, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
        y_conv = tf.nn.softmax(core.inference(x_image, keep_prob=1.0))
        prediction = tf.argmax(y_conv, 1)

        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver = tf.train.Saver()
            saver.restore(sess, ckpt.model_checkpoint_path)

            original_img = Image.open(image_path)
            original_img.seek(image_index)
            (width, height) = original_img.size
            result = numpy.zeros((height, width))
            for i in range(0, height):
                for j in range(0, width):
                    box = (j - CROP_SIZE, i - CROP_SIZE, j + CROP_SIZE,
                           i + CROP_SIZE)
                    image = numpy.array(original_img.crop(box))
                    image = numpy.reshape(image, (1, IMAGE_SIZE * IMAGE_SIZE))
                    image = image.astype(numpy.float32)
                    image = numpy.multiply(image, 1.0 / 255.0)
                    y_pred = prediction.eval(feed_dict={x: image})
                    print("%d, %d" % (i, j))
                    result[i][j] = y_pred
            filename = "%s_%d" % (image_path.split('\\')[-1].split('.')[0],
                                  image_index + 1)
            scipy.io.savemat(os.path.join(FLAGS.output_dir, filename),
                             mdict={'result': result})
            duration = time.time() - start_time
            print('%d seconds' % int(duration))
        else:
            print('No checkpoint file found')
            return
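The inner loop above runs one session call per pixel, which dominates the runtime. Since the placeholder x already accepts a batch dimension, a hedged sketch of a faster variant (same names as above) classifies a whole row of crops per call:

for i in range(height):
    row = numpy.stack([
        numpy.asarray(
            original_img.crop((j - CROP_SIZE, i - CROP_SIZE,
                               j + CROP_SIZE, i + CROP_SIZE)),
            dtype=numpy.float32).reshape(IMAGE_SIZE * IMAGE_SIZE) / 255.0
        for j in range(width)])
    result[i, :] = prediction.eval(feed_dict={x: row})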
Example #8
def vis(tfrecord_folder, dataset_split, is_training):
    with tf.Graph().as_default() as g:
        with tf.device('/cpu:0'):
            input_dict = input_pipeline.inputs(
                tfrecord_folder, dataset_split, is_training,
                is_vis=True, batch_size=FLAGS.batch_size, num_epochs=1)
            original_images = input_dict['original_image']
            images = input_dict['image']
            filename = input_dict['filename']

        logits = core.inference(FLAGS.model_variant, images, FLAGS.is_training)
        predictions = tf.argmax(logits, axis=-1)
        predictions = tf.expand_dims(predictions, axis=-1)
        predictions = tf.image.resize_nearest_neighbor(
            predictions, tf.shape(original_images)[1:3], align_corners=True)

        if dataset_split not in ['train', 'val', 'trainval', 'test']:
            raise ValueError('Invalid dataset_split: {}'.format(dataset_split))
        elif dataset_split == 'train':
            num_iters = input_pipeline.NUMBER_TRAIN_DATA
        elif dataset_split == 'val':
            num_iters = input_pipeline.NUMBER_VAL_DATA
        elif dataset_split == 'trainval':
            num_iters = input_pipeline.NUMBER_TRAINVAL_DATA
        else:
            # The original had no branch for 'test', leaving num_iters
            # undefined; assume the pipeline exposes a count for it.
            num_iters = input_pipeline.NUMBER_TEST_DATA

        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print('Get global_step from checkpoint name')
        else:
            global_step = tf.train.get_or_create_global_step()
            print('Create global_step.')

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.train.MonitoredSession(
                session_creator=tf.train.ChiefSessionCreator(
                    config=config,
                    checkpoint_dir=FLAGS.checkpoint_dir)) as mon_sess:
            colormap = create_pascal_label_colormap()
            cur_iter = 0
            while cur_iter < num_iters:
                (original_image, prediction, image_name) = mon_sess.run(
                    [original_images, predictions, filename])
                original_image = np.squeeze(original_image)
                prediction = np.squeeze(prediction)
                image_name = image_name[0]
                print('Visualizing {}'.format(image_name))

                pil_image = Image.fromarray(original_image)
                pil_image.save(
                    '{}/{}.png'.format(FLAGS.vis_dir, image_name),
                    format='PNG')

                prediction = colormap[prediction]
                pil_prediction = Image.fromarray(prediction.astype(dtype=np.uint8))
                pil_prediction.save(
                    '{}/{}_prediction.png'.format(FLAGS.vis_dir, image_name),
                    format='PNG')

                cur_iter += 1

            print('Finished!')
Example #9
def train(model_variant, tfrecord_dir, dataset_split):
    with tf.Graph().as_default() as g:
        global_step = tf.train.get_or_create_global_step()

        with tf.device('/cpu:0'):
            images, labels = input_pipeline.inputs(tfrecord_dir,
                                                   dataset_split,
                                                   FLAGS.is_training,
                                                   FLAGS.batch_size,
                                                   num_epochs=None)

        predictions = core.inference(model_variant,
                                     images,
                                     is_training=FLAGS.is_training)

        total_loss = core.loss(predictions, labels, FLAGS.weights)

        # metric
        mean_absolute_error, update_op = tf.metrics.mean_absolute_error(
            labels=labels,
            predictions=predictions,
            updates_collections=tf.GraphKeys.UPDATE_OPS,
            name='mean_absolute_error')
        tf.summary.scalar('train/mean_absolute_error', update_op)

        steps_per_epoch = int(np.ceil(input_pipeline.NUMBER_TRAIN_DATA /
                                      float(FLAGS.batch_size)))
        decay_steps = FLAGS.decay_epochs * steps_per_epoch

        learning_rate = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                                   global_step,
                                                   decay_steps,
                                                   FLAGS.decay_rate,
                                                   staircase=FLAGS.staircase)
        tf.summary.scalar('learning_rate', learning_rate)

        with tf.variable_scope('adam_vars'):
            optimizer = tf.train.AdamOptimizer(learning_rate)
            # update moving mean/var in batch_norm layers
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(total_loss, global_step)

        adam_vars = optimizer.variables()
        # def name_in_checkpoint(var):
        #     return var.op.name.replace(FLAGS.model_variant, 'vgg_16')
        variables_to_restore = slim.get_variables_to_restore(
            exclude=(models.EXCLUDE_LIST_MAP[FLAGS.model_variant] +
                     ['global_step', 'adam_vars']))
        # variables_to_restore = {name_in_checkpoint(var):var
        #     for var in variables_to_restore
        #         if not 'BatchNorm' in var.op.name}
        # variables_to_restore = {name_in_checkpoint(var):var
        #     for var in variables_to_restore}

        if FLAGS.restore_ckpt_path:
            restorer = tf.train.Saver(variables_to_restore)

            def init_fn(scaffold, sess):
                restorer.restore(sess, FLAGS.restore_ckpt_path)
        else:
            init_fn = None

        class _LoggerHook(tf.train.SessionRunHook):
            def begin(self):
                # Assuming model_checkpoint_path looks something like:
                #   /my-favorite-path/cifar10_train/model.ckpt-0,
                # extract global_step from it.
                ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    self._step = int(ckpt.model_checkpoint_path
                                     .split('/')[-1].split('-')[-1]) - 1
                else:
                    self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(
                    total_loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = (FLAGS.log_frequency *
                                        FLAGS.batch_size / duration)
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.5f '
                                  '(%.1f examples/sec; %.3f sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        training_steps = steps_per_epoch * FLAGS.num_epochs

        config = tf.ConfigProto(log_device_placement=False)
        config.gpu_options.allow_growth = True

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=tf.train.Scaffold(init_fn=init_fn),
                hooks=[
                    tf.train.StopAtStepHook(last_step=training_steps),
                    tf.train.NanTensorHook(total_loss),
                    _LoggerHook()
                ],
                config=config,
                save_checkpoint_steps=FLAGS.save_checkpoint_steps) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
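tf.train.exponential_decay above implements

    learning_rate = initial_learning_rate * decay_rate ** (global_step / decay_steps)

where the exponent is truncated to an integer when staircase=True, so with decay_steps = decay_epochs * steps_per_epoch the rate drops in discrete steps once every decay_epochs epochs rather than continuously.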
Example #10
def vis(model_variant, tfrecord_dir, dataset_split):
    with tf.Graph().as_default() as g:
        with tf.device('/cpu:0'):
            images, labels = input_pipeline.inputs(tfrecord_dir,
                                                   dataset_split,
                                                   FLAGS.is_training,
                                                   FLAGS.batch_size,
                                                   num_epochs=1)
        predictions = core.inference(model_variant,
                                     images,
                                     is_training=FLAGS.is_training)

        if dataset_split not in ['train', 'val', 'test']:
            raise ValueError('Invalid dataset_split: {}'.format(dataset_split))
        elif dataset_split == 'train':
            num_iters = input_pipeline.NUMBER_TRAIN_DATA
        elif dataset_split == 'val':
            num_iters = input_pipeline.NUMBER_VAL_DATA
        elif dataset_split == 'test':
            num_iters = input_pipeline.NUMBER_TEST_DATA

        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print('Get global_step from checkpoint name')
        else:
            global_step = tf.train.get_or_create_global_step()
            print('Create global_step.')

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True

        with tf.train.MonitoredSession(
                session_creator=tf.train.ChiefSessionCreator(
                    config=config,
                    checkpoint_dir=FLAGS.checkpoint_dir)) as mon_sess:
            cur_iter = 0
            while cur_iter < num_iters:
                print('Visualizing {:06d}.png'.format(cur_iter))
                image, labels_, predictions_ = mon_sess.run(
                    [images, labels, predictions])
                # image shape is (224, 224, 3)
                image = np.squeeze(image)
                image = np.uint8(image * 255.0)
                labels_ = np.squeeze(labels_)

                predictions_ = predictions_ * input_pipeline.IMAGE_SHAPE[0]
                labels_ = labels_ * input_pipeline.IMAGE_SHAPE[0]

                predictions_ = list(
                    np.split(predictions_, indices_or_sections=2))
                labels_ = list(np.split(labels_, indices_or_sections=2))
                pil_image = Image.fromarray(image)
                draw = ImageDraw.Draw(pil_image)
                for prediction, label in zip(predictions_, labels_):
                    prediction = list(prediction)
                    label = list(label)
                    # top_left -> top_right -> bottom_right -> bottom_left
                    prediction = (prediction[:4] + prediction[-2:] +
                                  prediction[4:6])
                    label = label[:4] + label[-2:] + label[4:6]
                    # point is represented as (width, height) in PIL
                    prediction = (prediction[0:2][::-1] +
                                  prediction[2:4][::-1] +
                                  prediction[4:6][::-1] +
                                  prediction[-2:][::-1])
                    label = (label[0:2][::-1] + label[2:4][::-1] +
                             label[4:6][::-1] + label[-2:][::-1])

                    draw.polygon(label, outline=(0, 255, 0))
                    draw.polygon(prediction, outline=(255, 0, 0))

                pil_image.save('{}/{:06}.png'.format(FLAGS.vis_dir, cur_iter),
                               format='PNG')

                cur_iter += 1

            print('Finished!')
Example #11
def train(tfrecord_folder, dataset_split, is_training):
    with tf.Graph().as_default() as g:
        global_step = tf.train.get_or_create_global_step()

        with tf.device('/cpu:0'):
            input_data = input_pipeline.inputs(
                tfrecord_folder, dataset_split, is_training, is_vis=False,
                batch_size=FLAGS.batch_size, num_epochs=None)
            images = input_data['image']
            labels = input_data['label']

        tf.summary.image('images', images)
        tf.summary.image('labels', tf.cast(labels, tf.uint8))

        logits = core.inference(FLAGS.model_variant, images, FLAGS.is_training)
        total_loss = core.loss(logits, labels)

        tf.summary.histogram('logits', logits)

        learning_rate = tf.train.exponential_decay(
            FLAGS.initial_learning_rate, global_step, FLAGS.decay_steps,
            FLAGS.decay_rate, staircase=FLAGS.staircase)

        tf.summary.scalar('learning_rate', learning_rate)
        for var in tf.model_variables():
            tf.summary.histogram(var.op.name, var)

        with tf.variable_scope('adam'):
            optimizer = tf.train.AdamOptimizer(learning_rate)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(total_loss, global_step)

        variables_to_restore = slim.get_variables_to_restore(
            exclude=(core.MODEL_MAP[FLAGS.model_variant].exclude_list()
                     + ['adam', 'global_step']))

        def name_in_checkpoint(var):
            return var.op.name.replace(FLAGS.model_variant + '/', '')

        variables_to_restore = {
            name_in_checkpoint(var): var for var in variables_to_restore
            if 'vgg_16' in var.op.name
        }

        restorer = tf.train.Saver(variables_to_restore)
        def init_fn(scaffold, sess):
            restorer.restore(sess, FLAGS.restore_ckpt_path)

        class _LoggerHook(tf.train.SessionRunHook):
            def begin(self):
                # Assuming model_checkpoint_path looks something like:
                #   /my-favorite-path/cifar10_train/model.ckpt-0,
                # extract global_step from it.
                ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    self._step = int(
                        ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) - 1
                else:
                    self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                loss_dict = {
                    'cross_entropy': g.get_tensor_by_name(
                        'cross_entropy_loss:0'),
                    'regularization': g.get_tensor_by_name(
                        'regularization_loss:0'),
                    'total': g.get_tensor_by_name('total_loss:0')
                }
                return tf.train.SessionRunArgs(loss_dict)

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_dict = run_values.results
                    examples_per_sec = (FLAGS.log_frequency
                                        * FLAGS.batch_size / duration)
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, '
                                  'total_loss = %.5f, '
                                  'cross_entropy = %.5f, '
                                  'regularization = %.5f, '
                                  '(%.1f examples/sec; %.3f sec/batch)')
                    print(format_str % (datetime.now(), self._step,
                                        loss_dict['total'],
                                        loss_dict['cross_entropy'],
                                        loss_dict['regularization'],
                                        examples_per_sec, sec_per_batch))

        config = tf.ConfigProto(log_device_placement=False)
        config.gpu_options.allow_growth = True
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=tf.train.Scaffold(init_fn=init_fn),
                hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                       tf.train.NanTensorHook(total_loss),
                       _LoggerHook()],
                config=config,
                save_summaries_steps=FLAGS.save_summaries_steps,
                save_checkpoint_steps=FLAGS.save_checkpoint_steps) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
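The _LoggerHook above fetches the individual losses by tensor name via g.get_tensor_by_name, which only works if core.loss registers tensors under exactly those names. A hypothetical sketch of the naming pattern it assumes (the real core.loss may differ):

def loss(logits, labels):
    # Name each term so hooks can retrieve '<name>:0' from the graph.
    cross_entropy = tf.identity(
        tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits),
        name='cross_entropy_loss')
    regularization = tf.identity(
        tf.losses.get_regularization_loss(), name='regularization_loss')
    return tf.add(cross_entropy, regularization, name='total_loss')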