Example #1
0
def eval(tfrecord_folder, dataset_split, is_training):
    """Compute the mean IoU of the latest checkpoint over one pass of a split.

    Accumulates `tf.metrics.mean_iou` over `num_batches` batches (sized from
    `input_pipeline.NUMBER_VAL_DATA`), writes the final value as a `mean_iou`
    scalar summary to `FLAGS.eval_dir`, and prints it.

    Args:
        tfrecord_folder: Forwarded to `input_pipeline.inputs`.
        dataset_split: Forwarded to `input_pipeline.inputs`.
        is_training: Forwarded to `input_pipeline.inputs` only; the model
            itself is built with `FLAGS.is_training`.
    """
    with tf.Graph().as_default() as g:
        with tf.device('/cpu:0'):
            input_dict = input_pipeline.inputs(
                tfrecord_folder, dataset_split, is_training, is_vis=False,
                batch_size=FLAGS.batch_size, num_epochs=1)
            images = input_dict['image']
            labels = input_dict['label']

        # Drop the trailing channel axis of the label map.
        labels = tf.squeeze(labels, axis=[-1])
        logits = core.inference(FLAGS.model_variant, images, FLAGS.is_training)
        predictions = tf.argmax(
            logits, axis=-1, name='prediction', output_type=tf.int64)

        # IGNORE_LABEL pixels get weight 0 so they do not affect the IoU; they
        # are remapped to class 0 only so every label is a valid class id.
        weights = tf.to_float(tf.not_equal(labels, core.IGNORE_LABEL))
        labels = tf.where(tf.equal(labels, core.IGNORE_LABEL),
                          tf.zeros_like(labels),
                          labels)
        mean_iou, update_op = tf.metrics.mean_iou(
            labels=labels, predictions=predictions,
            num_classes=core.NUMBER_CLASSES, weights=weights, name='mean_iou')

        summary_op = tf.summary.scalar('mean_iou', mean_iou)

        num_batches = int(
            math.ceil(input_pipeline.NUMBER_VAL_DATA / float(FLAGS.batch_size)))

        # Get global_step used in summary_writer. With a checkpoint, parse it
        # from the "<prefix>-<step>" file name. Otherwise create the step
        # variable now (before the session finalizes the graph) and read its
        # value once the session is up: add_summary needs a number, not a
        # Variable.
        global_step = None
        global_step_var = None
        ckpt = tf.train.get_checkpoint_state(
            checkpoint_dir=FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print('Get global_step from checkpoint name.')
        else:
            global_step_var = tf.train.get_or_create_global_step()
            print('Create gloabl_step')

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
        try:
            with tf.train.MonitoredSession(
                    session_creator=tf.train.ChiefSessionCreator(
                        config=config,
                        checkpoint_dir=FLAGS.checkpoint_dir
                    )) as mon_sess:
                for _ in range(num_batches):
                    mon_sess.run(update_op)

                if global_step_var is not None:
                    global_step = int(mon_sess.run(global_step_var))

                summary = mon_sess.run(summary_op)
                summary_writer.add_summary(summary, global_step=global_step)
                summary_writer.flush()
                print('*' * 60)
                print('mean_iou:', mon_sess.run(mean_iou))
                print('*' * 60)
        finally:
            # Release the event file even if evaluation raises.
            summary_writer.close()
Example #2
0
def eval(model_variant, tfrecord_dir, dataset_split):
    """Compute the mean absolute error of the latest checkpoint on a split.

    Accumulates `tf.metrics.mean_absolute_error` over enough batches to cover
    `input_pipeline.NUMBER_VAL_DATA` examples, writes the result as the
    `eval/mean_absolute_error` scalar summary to `FLAGS.eval_dir`, and prints
    it together with the evaluated global step.

    Args:
        model_variant: Model name passed to `core.inference`.
        tfrecord_dir: Forwarded to `input_pipeline.inputs`.
        dataset_split: Forwarded to `input_pipeline.inputs`.
    """
    with tf.Graph().as_default() as g:
        with tf.device('/cpu:0'):
            images, labels = input_pipeline.inputs(tfrecord_dir,
                                                   dataset_split,
                                                   FLAGS.is_training,
                                                   FLAGS.batch_size,
                                                   num_epochs=1)
        predictions = core.inference(model_variant, images, FLAGS.is_training)

        mean_absolute_error, update_op = tf.metrics.mean_absolute_error(
            labels=labels, predictions=predictions, name='mean_absolute_error')

        summary_op = tf.summary.scalar('eval/mean_absolute_error',
                                       mean_absolute_error)

        # float() keeps the division fractional under Python 2 as well, so
        # ceil() counts the final partial batch instead of dropping it.
        num_batches = int(
            math.ceil(input_pipeline.NUMBER_VAL_DATA / float(FLAGS.batch_size)))

        # Get global_step used in summary_writer. With a checkpoint, parse it
        # from the "<prefix>-<step>" file name. Otherwise create the step
        # variable now and read its value inside the session: both
        # add_summary and the '{:06}' format below need a number, not a
        # Variable.
        global_step = None
        global_step_var = None
        ckpt = tf.train.get_checkpoint_state(
            checkpoint_dir=FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print('Get global_step from checkpoint name.')
        else:
            global_step_var = tf.train.get_or_create_global_step()
            print('Create gloabl_step')

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True

        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
        try:
            with tf.train.MonitoredSession(
                    session_creator=tf.train.ChiefSessionCreator(
                        config=config,
                        checkpoint_dir=FLAGS.checkpoint_dir)) as mon_sess:
                for _ in range(num_batches):
                    mon_sess.run(update_op)

                if global_step_var is not None:
                    global_step = int(mon_sess.run(global_step_var))

                summary = mon_sess.run(summary_op)
                summary_writer.add_summary(summary, global_step=global_step)
                summary_writer.flush()
                print('*' * 50)
                print('Step {:06} mean_absolute_error:'.format(global_step),
                      mon_sess.run(mean_absolute_error))
                print('*' * 50)
        finally:
            # Release the event file even if evaluation raises.
            summary_writer.close()
Example #3
0
def main(tfrecord_dir, dataset_split):
    """Render one example's ground-truth quadrilaterals to SAVE_DIR/test.png.

    Pulls a single example (batch size 1, one epoch) from `dataset_split`.
    The image is multiplied by 255 and cast to uint8 (presumably it arrives
    in [0, 1] — TODO confirm). The label vector is multiplied by 224, cast
    to int32, split into two halves, and each half is drawn as a red polygon.
    Calls `.numpy()` on the pipeline outputs, so eager execution is required.
    """
    image_t, label_t = input_pipeline.inputs(tfrecord_dir, dataset_split,
                                             True, 1, 1)
    pixels = tf.squeeze(image_t).numpy()
    coords = tf.squeeze(label_t).numpy()

    canvas = Image.fromarray((pixels * 255).astype(np.uint8))
    pen = ImageDraw.Draw(canvas)

    # Each half is a flat run of coordinate pairs. The pairs are visited in
    # the order 1st, 2nd, last, 3rd, and each pair is reversed before
    # drawing (presumably (row, col) -> PIL's (x, y)).
    for quad in np.split((coords * 224).astype(np.int32), 2, axis=0):
        flat = list(quad)
        polygon = []
        for pair in (flat[0:2], flat[2:4], flat[-2:], flat[4:6]):
            polygon.extend(reversed(pair))
        pen.polygon(polygon, outline=(255, 0, 0))

    canvas.save(os.path.join(SAVE_DIR, 'test.png'), 'PNG')
Example #4
0
def vis(tfrecord_folder, dataset_split, is_training):
    """Save every image of a split alongside its colorized prediction.

    For each example, writes `<FLAGS.vis_dir>/<name>.png` (the original image)
    and `<FLAGS.vis_dir>/<name>_prediction.png`. The prediction is the
    arg-max class map, resized to the original resolution with
    nearest-neighbour sampling and mapped through the PASCAL colormap. Only
    element `[0]` of each fetched filename batch is used, so this presumably
    expects `FLAGS.batch_size == 1` — TODO confirm.

    Args:
        tfrecord_folder: Forwarded to `input_pipeline.inputs`.
        dataset_split: One of 'train', 'val', 'trainval' or 'test'.
        is_training: Forwarded to `input_pipeline.inputs`.

    Raises:
        ValueError: if `dataset_split` is not an accepted split name.
    """
    # Accepted splits and the input_pipeline constant holding each size.
    split_size_names = {
        'train': 'NUMBER_TRAIN_DATA',
        'val': 'NUMBER_VAL_DATA',
        'trainval': 'NUMBER_TRAINVAL_DATA',
        'test': 'NUMBER_TEST_DATA',
    }
    if dataset_split not in split_size_names:
        raise ValueError('Invalid argument.')
    # When the pipeline defines no size constant for a split (e.g. 'test'),
    # num_iters is None and the loop runs until the single epoch is exhausted.
    num_iters = getattr(input_pipeline, split_size_names[dataset_split], None)

    with tf.Graph().as_default() as g:
        with tf.device('/cpu:0'):
            input_dict = input_pipeline.inputs(
                tfrecord_folder, dataset_split, is_training,
                is_vis=True, batch_size=FLAGS.batch_size, num_epochs=1)
            original_images = input_dict['original_image']
            images = input_dict['image']
            filename = input_dict['filename']

        logits = core.inference(FLAGS.model_variant, images, FLAGS.is_training)
        predictions = tf.argmax(logits, axis=-1)
        predictions = tf.expand_dims(predictions, axis=-1)
        # Resize the class map back to the original image's height/width.
        predictions = tf.image.resize_nearest_neighbor(
            predictions, tf.shape(original_images)[1:3], align_corners=True)

        # Only used for logging; the checkpoint itself is restored below.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print('Get global_step from checkpoint name')
        else:
            global_step = tf.train.get_or_create_global_step()
            print('Create global_step.')

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.train.MonitoredSession(
                session_creator=tf.train.ChiefSessionCreator(
                    config=config,
                    checkpoint_dir=FLAGS.checkpoint_dir)) as mon_sess:
            colormap = create_pascal_label_colormap()
            cur_iter = 0
            while num_iters is None or cur_iter < num_iters:
                try:
                    (original_image, prediction, image_name) = mon_sess.run(
                        [original_images, predictions, filename])
                except tf.errors.OutOfRangeError:
                    # num_epochs=1: the input pipeline is exhausted.
                    break
                original_image = np.squeeze(original_image)
                prediction = np.squeeze(prediction)
                image_name = image_name[0]
                if not isinstance(image_name, str):
                    # tf.string values are bytes under Python 3; decode so the
                    # output file is not named "b'...'.png".
                    image_name = image_name.decode('utf-8')
                print('Visualing {}'.format(image_name))

                pil_image = Image.fromarray(original_image)
                pil_image.save(
                    '{}/{}.png'.format(FLAGS.vis_dir, image_name),
                    format='PNG')

                # Map each class id to its RGB colour.
                prediction = colormap[prediction]
                pil_prediction = Image.fromarray(prediction.astype(dtype=np.uint8))
                pil_prediction.save(
                    '{}/{}_prediction.png'.format(FLAGS.vis_dir, image_name),
                    format='PNG')

                cur_iter += 1

            print('Finished!')
Example #5
0
def train(model_variant, tfrecord_dir, dataset_split):
    """Train a regression model with Adam under a MonitoredTrainingSession.

    Builds the input pipeline, model and loss, plus an exponentially decaying
    learning rate (decayed every `FLAGS.decay_epochs` epochs). If
    `FLAGS.restore_ckpt_path` is set, the model weights (minus the per-model
    exclude list, the step counter and the optimizer state) are initialised
    from it. Training stops after `FLAGS.num_epochs` epochs' worth of steps.
    Progress is logged every `FLAGS.log_frequency` steps, and checkpoints are
    written to `FLAGS.train_dir`.

    Args:
        model_variant: Model name passed to `core.inference`.
        tfrecord_dir: Forwarded to `input_pipeline.inputs`.
        dataset_split: Forwarded to `input_pipeline.inputs`.
    """
    with tf.Graph().as_default() as g:
        global_step = tf.train.get_or_create_global_step()

        with tf.device('/cpu:0'):
            images, labels = input_pipeline.inputs(tfrecord_dir,
                                                   dataset_split,
                                                   FLAGS.is_training,
                                                   FLAGS.batch_size,
                                                   num_epochs=None)

        predictions = core.inference(model_variant,
                                     images,
                                     is_training=FLAGS.is_training)

        total_loss = core.loss(predictions, labels, FLAGS.weights)

        # Streaming MAE metric. Its update op goes into UPDATE_OPS, so it
        # runs as a control dependency of train_op below.
        mean_absolute_error, update_op = tf.metrics.mean_absolute_error(
            labels=labels,
            predictions=predictions,
            updates_collections=tf.GraphKeys.UPDATE_OPS,
            name='mean_absolute_error')
        tf.summary.scalar('train/mean_absolute_error', update_op)

        # float() so ceil() rounds up even under Python 2 integer division.
        steps_per_epoch = np.ceil(input_pipeline.NUMBER_TRAIN_DATA /
                                  float(FLAGS.batch_size))
        decay_steps = FLAGS.decay_epochs * steps_per_epoch

        learning_rate = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                                   global_step,
                                                   decay_steps,
                                                   FLAGS.decay_rate,
                                                   staircase=FLAGS.staircase)
        tf.summary.scalar('learning_rate', learning_rate)

        with tf.variable_scope('adam_vars'):
            optimizer = tf.train.AdamOptimizer(learning_rate)
            # Update moving mean/var in batch_norm layers (and the MAE metric).
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(total_loss, global_step)

        # Exclude the per-model list, the step counter and the optimizer
        # state from the restore.
        variables_to_restore = slim.get_variables_to_restore(
            exclude=(models.EXCLUDE_LIST_MAP[FLAGS.model_variant] +
                     ['global_step', 'adam_vars']))

        if FLAGS.restore_ckpt_path:
            restorer = tf.train.Saver(variables_to_restore)

            def init_fn(scaffold, sess):
                restorer.restore(sess, FLAGS.restore_ckpt_path)
        else:
            init_fn = None

        class _LoggerHook(tf.train.SessionRunHook):
            """Prints loss and throughput every FLAGS.log_frequency steps."""

            def begin(self):
                # Assuming model_checkpoint_path looks something like:
                #   /my-favorite-path/cifar10_train/model.ckpt-0,
                # extract global_step from it.
                ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    self._step = int(
                        ckpt.model_checkpoint_path.split('/')[-1].split('-')
                        [-1]) - 1
                else:
                    self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(
                    total_loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = (FLAGS.log_frequency *
                                        FLAGS.batch_size / duration)
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.5f '
                                  '(%.1f examples/sec; %.3f sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        # Integer stop step. ceil() keeps the same stopping point that a
        # `step >= fractional_product` comparison would give.
        training_steps = int(np.ceil(steps_per_epoch * FLAGS.num_epochs))

        config = tf.ConfigProto(log_device_placement=False)
        config.gpu_options.allow_growth = True

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=tf.train.Scaffold(init_fn=init_fn),
                hooks=[
                    tf.train.StopAtStepHook(last_step=training_steps),
                    tf.train.NanTensorHook(total_loss),
                    _LoggerHook()
                ],
                config=config,
                save_checkpoint_steps=FLAGS.save_checkpoint_steps) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
Example #6
0
def train(tfrecord_folder, dataset_split, is_training):
    """Train a segmentation model with synchronous multi-tower SGD.

    Builds one loss tower per GPU (`FLAGS.num_gpus`) under a shared variable
    scope, averages the tower gradients with `core.average_gradients`, and
    applies them with exponentially decaying gradient descent. When
    `FLAGS.use_init_model` is set, `vgg_16` weights are initialised from
    `FLAGS.restore_ckpt_path`. Logs tower-0 losses every
    `FLAGS.log_frequency` steps.

    Args:
        tfrecord_folder: Forwarded to `input_pipeline.inputs`.
        dataset_split: Forwarded to `input_pipeline.inputs`.
        is_training: Forwarded to the input pipeline and `core.tower_loss`.

    Raises:
        ValueError: if `FLAGS.num_gpus` is less than 1.
    """
    if FLAGS.num_gpus < 1:
        raise ValueError('FLAGS.num_gpus must be at least 1.')

    with tf.Graph().as_default() as g:
        global_step = tf.train.get_or_create_global_step()

        with tf.device('/cpu:0'):
            input_data = input_pipeline.inputs(tfrecord_folder,
                                               dataset_split,
                                               is_training,
                                               is_vis=False,
                                               batch_size=FLAGS.batch_size,
                                               num_epochs=None)
            images = input_data['image']
            labels = input_data['label']

            tf.summary.image('images', images)
            tf.summary.image('labels', tf.cast(labels, tf.uint8))

        # float() so the epoch length stays fractional under Python 2 too.
        num_batches_per_epoch = (input_pipeline.NUMBER_TRAIN_DATA /
                                 float(FLAGS.batch_size))
        decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay)
        learning_rate = tf.train.exponential_decay(
            FLAGS.initial_learning_rate,
            global_step,
            decay_steps,
            FLAGS.learning_rate_decay_factor,
            staircase=FLAGS.staircase)
        tf.summary.scalar('learning_rate', learning_rate)

        # TODO(hhw): Change to adam optimizer.
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)

        tower_grads = []
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            # range (not the Python-2-only xrange); num_gpus is small.
            for i in range(FLAGS.num_gpus):
                with tf.device('/gpu:{}'.format(i)):
                    with tf.name_scope('{}_{}'.format('tower', i)) as scope:
                        # NOTE(review): every tower is fed the same
                        # images/labels batch, so extra GPUs recompute the
                        # same gradients rather than adding throughput.
                        # Confirm intent before splitting the batch per tower.
                        loss = core.tower_loss(FLAGS.model_variant, images,
                                               labels, is_training, scope)

                        grads = optimizer.compute_gradients(
                            loss, tf.trainable_variables())

                        tower_grads.append(grads)

        grads = core.average_gradients(tower_grads)

        # Run batch-norm moving-average updates together with each step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.apply_gradients(grads,
                                                 global_step=global_step)

        if FLAGS.use_init_model:

            def name_in_checkpoint(var):
                # Strip the model-variant prefix to match checkpoint names.
                return var.op.name.replace(FLAGS.model_variant + '/', '')

            variables_to_restore = slim.get_variables_to_restore(
                exclude=(core.MODEL_MAP[FLAGS.model_variant].exclude_list() +
                         ['global_step', 'adam']))
            variables_to_restore = {
                name_in_checkpoint(var): var
                for var in variables_to_restore if 'vgg_16' in var.op.name
            }

            restorer = tf.train.Saver(variables_to_restore)

            def init_fn(scaffold, sess):
                restorer.restore(sess, FLAGS.restore_ckpt_path)
        else:
            init_fn = None

        class _LoggerHook(tf.train.SessionRunHook):
            """Prints tower-0 losses and throughput every log_frequency steps."""

            def begin(self):
                ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    self._step = int(
                        ckpt.model_checkpoint_path.split('/')[-1].split('-')
                        [-1]) - 1
                else:
                    self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                loss_dict = {
                    'cross_entropy':
                    g.get_tensor_by_name(
                        'tower_{}/cross_entropy_loss:0'.format(0)),
                    'regularization':
                    g.get_tensor_by_name(
                        'tower_{}/regularization_loss:0'.format(0)),
                    'total':
                    g.get_tensor_by_name('tower_{}/total_loss:0'.format(0)),
                }
                return tf.train.SessionRunArgs(loss_dict)

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_dict = run_values.results
                    examples_per_sec = (FLAGS.log_frequency *
                                        FLAGS.batch_size / duration)
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    # Fixed 3-decimal sec/batch (the old '{:02.3}' spec used
                    # general formatting and could switch to e-notation).
                    format_str = ('{}: step {}, '
                                  'tower_loss = {:7.5f}, '
                                  'cross_entropy_loss = {:7.5f}, '
                                  'regularization_loss = {:7.5f}, '
                                  '({:5.3f} examples/sec; {:.3f} sec/batch)')
                    print(
                        format_str.format(datetime.now(), self._step,
                                          loss_dict['total'],
                                          loss_dict['cross_entropy'],
                                          loss_dict['regularization'],
                                          examples_per_sec, sec_per_batch))

        num_train_steps = int(num_batches_per_epoch * FLAGS.num_epochs)
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        config.gpu_options.allow_growth = True
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=tf.train.Scaffold(init_fn=init_fn),
                hooks=[
                    tf.train.StopAtStepHook(last_step=num_train_steps),
                    # Watches the last tower's loss.
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=config,
                save_summaries_steps=FLAGS.save_summaries_steps,
                save_checkpoint_steps=FLAGS.save_checkpoint_steps) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
Example #7
0
def vis(model_variant, tfrecord_dir, dataset_split):
    """Draw ground-truth (green) and predicted (red) quadrilaterals per image.

    For each example of `dataset_split`, saves `<FLAGS.vis_dir>/<iter>.png`.
    The image shows two label polygons and two prediction polygons, each
    with four corners. Coordinates are multiplied by
    `input_pipeline.IMAGE_SHAPE[0]`, so they are presumably normalised to
    [0, 1] — TODO confirm. The squeezes below assume `FLAGS.batch_size == 1`.

    Args:
        model_variant: Model name passed to `core.inference`.
        tfrecord_dir: Forwarded to `input_pipeline.inputs`.
        dataset_split: One of 'train', 'val' or 'test'.

    Raises:
        ValueError: if `dataset_split` is not an accepted split name.
    """
    # Validate before building the graph so a bad split fails fast.
    if dataset_split == 'train':
        num_iters = input_pipeline.NUMBER_TRAIN_DATA
    elif dataset_split == 'val':
        num_iters = input_pipeline.NUMBER_VAL_DATA
    elif dataset_split == 'test':
        num_iters = input_pipeline.NUMBER_TEST_DATA
    else:
        raise ValueError('Invalid argument.')

    with tf.Graph().as_default() as g:
        with tf.device('/cpu:0'):
            images, labels = input_pipeline.inputs(tfrecord_dir,
                                                   dataset_split,
                                                   FLAGS.is_training,
                                                   FLAGS.batch_size,
                                                   num_epochs=1)
        predictions = core.inference(model_variant,
                                     images,
                                     is_training=FLAGS.is_training)

        # Only used for logging; the checkpoint itself is restored below.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print('Get global_step from checkpoint name')
        else:
            global_step = tf.train.get_or_create_global_step()
            print('Create global_step.')

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True

        with tf.train.MonitoredSession(
                session_creator=tf.train.ChiefSessionCreator(
                    config=config,
                    checkpoint_dir=FLAGS.checkpoint_dir)) as mon_sess:
            cur_iter = 0
            while cur_iter < num_iters:
                print('Visualizing {:06d}.png'.format(cur_iter))
                image, labels_, predictions_ = mon_sess.run(
                    [images, labels, predictions])
                # image shape is (224, 224, 3)
                image = np.squeeze(image)
                image = np.uint8(image * 255.0)
                # Drop the size-1 batch axis from BOTH coordinate arrays.
                # Otherwise np.split below would try to split that axis into
                # two and raise.
                labels_ = np.squeeze(labels_)
                predictions_ = np.squeeze(predictions_)

                predictions_ = predictions_ * input_pipeline.IMAGE_SHAPE[0]
                labels_ = labels_ * input_pipeline.IMAGE_SHAPE[0]

                predictions_ = list(
                    np.split(predictions_, indices_or_sections=2))
                labels_ = list(np.split(labels_, indices_or_sections=2))
                pil_image = Image.fromarray(image)
                draw = ImageDraw.Draw(pil_image)
                for prediction, label in zip(predictions_, labels_):
                    prediction = list(prediction)
                    label = list(label)
                    # top_left -> top_right -> bottom_right -> bottom_left
                    prediction = prediction[:4] + prediction[-2:] + prediction[
                        4:6]
                    label = label[:4] + label[-2:] + label[4:6]
                    # point is represented as (width, height) in PIL
                    prediction = (prediction[0:2][::-1] +
                                  prediction[2:4][::-1] +
                                  prediction[4:6][::-1] +
                                  prediction[-2:][::-1])
                    label = (label[0:2][::-1] + label[2:4][::-1] +
                             label[4:6][::-1] + label[-2:][::-1])

                    draw.polygon(label, outline=(0, 255, 0))
                    draw.polygon(prediction, outline=(255, 0, 0))

                pil_image.save('{}/{:06}.png'.format(FLAGS.vis_dir, cur_iter),
                               format='PNG')

                cur_iter += 1

            print('Finished!')
Example #8
0
def train(tfrecord_folder, dataset_split, is_training):
    """Train a segmentation model with Adam under a MonitoredTrainingSession.

    Builds the input pipeline, model and loss, and an exponentially decaying
    learning rate (`FLAGS.decay_steps` / `FLAGS.decay_rate`). When
    `FLAGS.restore_ckpt_path` is set, `vgg_16` weights are initialised from
    it, excluding the per-model exclude list, the step counter and the Adam
    state. Training stops at `FLAGS.max_steps`. Total, cross-entropy and
    regularisation losses are logged every `FLAGS.log_frequency` steps.

    Args:
        tfrecord_folder: Forwarded to `input_pipeline.inputs`.
        dataset_split: Forwarded to `input_pipeline.inputs`.
        is_training: Forwarded to `input_pipeline.inputs`; the model itself is
            built with `FLAGS.is_training`.
    """
    with tf.Graph().as_default() as g:
        global_step = tf.train.get_or_create_global_step()

        with tf.device('/cpu:0'):
            input_data = input_pipeline.inputs(
                tfrecord_folder, dataset_split, is_training, is_vis=False,
                batch_size=FLAGS.batch_size, num_epochs=None)
            images = input_data['image']
            labels = input_data['label']

        tf.summary.image('images', images)
        tf.summary.image('labels', tf.cast(labels, tf.uint8))

        logits = core.inference(FLAGS.model_variant, images, FLAGS.is_training)
        total_loss = core.loss(logits, labels)

        tf.summary.histogram('logits', logits)

        learning_rate = tf.train.exponential_decay(
            FLAGS.initial_learning_rate, global_step, FLAGS.decay_steps,
            FLAGS.decay_rate, staircase=FLAGS.staircase)

        tf.summary.scalar('learning_rate', learning_rate)
        for var in tf.model_variables():
            tf.summary.histogram(var.op.name, var)

        with tf.variable_scope('adam'):
            optimizer = tf.train.AdamOptimizer(learning_rate)
            # Run batch-norm moving-average updates together with each step.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(total_loss, global_step)

        # Only build the restorer when there is something to restore from;
        # otherwise init_fn would call Saver.restore with an empty path.
        if FLAGS.restore_ckpt_path:
            variables_to_restore = slim.get_variables_to_restore(
                exclude=(core.MODEL_MAP[FLAGS.model_variant].exclude_list()
                         + ['adam', 'global_step']))

            def name_in_checkpoint(var):
                # Strip the model-variant prefix to match checkpoint names.
                return var.op.name.replace(FLAGS.model_variant + '/', '')

            variables_to_restore = {
                name_in_checkpoint(var): var for var in variables_to_restore
                if 'vgg_16' in var.op.name
            }

            restorer = tf.train.Saver(variables_to_restore)

            def init_fn(scaffold, sess):
                restorer.restore(sess, FLAGS.restore_ckpt_path)
        else:
            init_fn = None

        class _LoggerHook(tf.train.SessionRunHook):
            """Prints the three losses and throughput every log_frequency steps."""

            def begin(self):
                # Assuming model_checkpoint_path looks something like:
                #   /my-favorite-path/cifar10_train/model.ckpt-0,
                # extract global_step from it.
                ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    self._step = int(
                        ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) - 1
                else:
                    self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                loss_dict = {
                    'cross_entropy': g.get_tensor_by_name(
                        'cross_entropy_loss:0'),
                    'regularization': g.get_tensor_by_name(
                        'regularization_loss:0'),
                    'total': g.get_tensor_by_name('total_loss:0')
                }
                return tf.train.SessionRunArgs(loss_dict)

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_dict = run_values.results
                    examples_per_sec = (FLAGS.log_frequency
                                        * FLAGS.batch_size / duration)
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, '
                                  'total_loss = %.5f, '
                                  'cross_entropy = %.5f, '
                                  'regularization = %.5f, '
                                  '(%.1f examples/sec; %.3f sec/batch)')
                    print(format_str % (datetime.now(), self._step,
                                        loss_dict['total'],
                                        loss_dict['cross_entropy'],
                                        loss_dict['regularization'],
                                        examples_per_sec, sec_per_batch))

        config = tf.ConfigProto(log_device_placement=False)
        config.gpu_options.allow_growth = True
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=tf.train.Scaffold(init_fn=init_fn),
                hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                       tf.train.NanTensorHook(total_loss),
                       _LoggerHook()],
                config=config,
                save_summaries_steps=FLAGS.save_summaries_steps,
                save_checkpoint_steps=FLAGS.save_checkpoint_steps) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)