Example #1
    def _load_data(load_to_placeholder=True):
      # read batch input
      image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
          bbox_per_batch = imdb.read_batch()

      label_indices, bbox_indices, box_delta_values, mask_indices, box_values = \
          [], [], [], [], []
      aidx_set = set()
      num_discarded_labels = 0
      num_labels = 0
      for i in range(len(label_per_batch)): # batch_size
        for j in range(len(label_per_batch[i])): # number of annotations
          num_labels += 1
          if (i, aidx_per_batch[i][j]) not in aidx_set:
            aidx_set.add((i, aidx_per_batch[i][j]))
            label_indices.append(
                [i, aidx_per_batch[i][j], label_per_batch[i][j]])
            mask_indices.append([i, aidx_per_batch[i][j]])
            bbox_indices.extend(
                [[i, aidx_per_batch[i][j], k] for k in range(4)])
            box_delta_values.extend(box_delta_per_batch[i][j])
            box_values.extend(bbox_per_batch[i][j])
          else:
            num_discarded_labels += 1

      if load_to_placeholder:
        image_input = model.ph_image_input
        input_mask = model.ph_input_mask
        box_delta_input = model.ph_box_delta_input
        box_input = model.ph_box_input
        labels = model.ph_labels
      else:
        image_input = model.image_input
        input_mask = model.input_mask
        box_delta_input = model.box_delta_input
        box_input = model.box_input
        labels = model.labels

      feed_dict = {
          image_input: image_per_batch,
          input_mask: np.reshape(
              sparse_to_dense(
                  mask_indices, [mc.BATCH_SIZE, mc.ANCHORS],
                  [1.0]*len(mask_indices)),
              [mc.BATCH_SIZE, mc.ANCHORS, 1]),
          box_delta_input: sparse_to_dense(
              bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
              box_delta_values),
          box_input: sparse_to_dense(
              bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
              box_values),
          labels: sparse_to_dense(
              label_indices,
              [mc.BATCH_SIZE, mc.ANCHORS, mc.CLASSES],
              [1.0]*len(label_indices)),
      }

      return feed_dict, image_per_batch, label_per_batch, bbox_per_batch
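
Every example on this page scatters per-anchor values into dense batch tensors through a small sparse_to_dense helper. The SqueezeDet utility itself is not shown here, so below is a minimal numpy sketch consistent with how it is called in these examples; the default_value parameter is an assumption:

    import numpy as np

    def sparse_to_dense(sp_indices, output_shape, values, default_value=0):
        # Scatter `values` into a dense array of shape `output_shape`.
        # sp_indices holds one index list per value, e.g. [batch, anchor, coord].
        assert len(sp_indices) == len(values), \
            'sp_indices and values must have the same length'
        array = np.ones(output_shape) * default_value
        for idx, value in zip(sp_indices, values):
            array[tuple(idx)] = value
        return array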
Example #2
        def _load_data_per_scale(label_per_batch, box_delta_per_batch,
                                 aidx_per_batch, bbox_per_batch, s, ANCHORS):
            label_indices, bbox_indices, box_delta_values, mask_indices, \
                box_values = [], [], [], [], []
            aidx_set = set()
            num_discarded_labels = 0  # labels dropped because their anchor is taken

            for i in range(len(label_per_batch)):  # batch_size
                # number of annotations with IOU > 0.5
                for j in range(len(aidx_per_batch[i][s])):
                    aidx, gt_id = aidx_per_batch[i][s][j]

                    if (i, aidx) not in aidx_set:
                        aidx_set.add((i, aidx))
                        label_indices.append(
                            [i, aidx, label_per_batch[i][gt_id]])
                        mask_indices.append([i, aidx])
                        bbox_indices.extend([[i, aidx, k] for k in range(4)])
                        box_delta_values.extend(box_delta_per_batch[i][s][j])
                        box_values.extend(bbox_per_batch[i][gt_id])
                    else:
                        num_discarded_labels += 1

            input_mask = np.reshape(
                sparse_to_dense(mask_indices, [mc.BATCH_SIZE, ANCHORS],
                                [1.0] * len(mask_indices)),
                [mc.BATCH_SIZE, ANCHORS, 1])
            box_delta_input = sparse_to_dense(bbox_indices,
                                              [mc.BATCH_SIZE, ANCHORS, 4],
                                              box_delta_values)
            box_input = sparse_to_dense(bbox_indices,
                                        [mc.BATCH_SIZE, ANCHORS, 4],
                                        box_values)
            labels = sparse_to_dense(label_indices,
                                     [mc.BATCH_SIZE, ANCHORS, mc.CLASSES],
                                     [1.0] * len(label_indices))

            return input_mask, box_delta_input, box_input, labels
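
_load_data_per_scale builds the dense tensors for a single anchor scale s, so a caller would presumably invoke it once per scale and feed each result to the matching per-scale placeholder. A hypothetical driver loop (NUM_SCALES, ANCHORS_PER_SCALE, and the per-scale placeholder lists are illustrative names, not from the source):

    feed_dict = {model.image_input: image_per_batch}
    for s in range(mc.NUM_SCALES):
        input_mask, box_delta_input, box_input, labels = _load_data_per_scale(
            label_per_batch, box_delta_per_batch, aidx_per_batch,
            bbox_per_batch, s, mc.ANCHORS_PER_SCALE[s])
        feed_dict[model.input_masks[s]] = input_mask
        feed_dict[model.box_delta_inputs[s]] = box_delta_input
        feed_dict[model.box_inputs[s]] = box_input
        feed_dict[model.labels_per_scale[s]] = labels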
Example #3
def train():
    """Train SqueezeDet model"""
    assert FLAGS.dataset == 'KITTI', \
        'Currently only supports the KITTI dataset'

    with tf.Graph().as_default():

        assert FLAGS.net in ['vgg16', 'resnet50', 'squeezeDet', 'squeezeDet+'], \
            'Selected neural net architecture not supported: {}'.format(FLAGS.net)
        if FLAGS.net == 'vgg16':
            mc = kitti_vgg16_config()
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = VGG16ConvDet(mc, FLAGS.gpu)
        elif FLAGS.net == 'resnet50':
            mc = kitti_res50_config()
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = ResNet50ConvDet(mc, FLAGS.gpu)
        elif FLAGS.net == 'squeezeDet':
            mc = kitti_squeezeDet_config()
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = SqueezeDet(mc, FLAGS.gpu)
        elif FLAGS.net == 'squeezeDet+':
            mc = kitti_squeezeDetPlus_config()
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = SqueezeDetPlus(mc, FLAGS.gpu)

        imdb = kitti(FLAGS.image_set, FLAGS.data_path, mc)

        # save model size, flops, activations by layers
        with open(os.path.join(FLAGS.train_dir, 'model_metrics.txt'),
                  'w') as f:
            f.write('Number of parameter by layer:\n')
            count = 0
            for c in model.model_size_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))

            count = 0
            f.write('\nActivation size by layer:\n')
            for c in model.activation_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))

            count = 0
            f.write('\nNumber of flops by layer:\n')
            for c in model.flop_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))
        print('Model statistics saved to {}.'.format(
            os.path.join(FLAGS.train_dir, 'model_metrics.txt')))

        saver = tf.train.Saver(tf.all_variables())
        summary_op = tf.merge_all_summaries()
        init = tf.initialize_all_variables()

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(init)
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

        for step in xrange(FLAGS.max_steps):
            start_time = time.time()

            # read batch input
            image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
                bbox_per_batch = imdb.read_batch()

            label_indices, bbox_indices, box_delta_values, mask_indices, \
                box_values = [], [], [], [], []
            aidx_set = set()
            num_discarded_labels = 0
            num_labels = 0
            for i in range(len(label_per_batch)):  # batch_size
                for j in range(len(
                        label_per_batch[i])):  # number of annotations
                    num_labels += 1
                    if (i, aidx_per_batch[i][j]) not in aidx_set:
                        aidx_set.add((i, aidx_per_batch[i][j]))
                        label_indices.append(
                            [i, aidx_per_batch[i][j], label_per_batch[i][j]])
                        mask_indices.append([i, aidx_per_batch[i][j]])
                        bbox_indices.extend([[i, aidx_per_batch[i][j], k]
                                             for k in range(4)])
                        box_delta_values.extend(box_delta_per_batch[i][j])
                        box_values.extend(bbox_per_batch[i][j])
                    else:
                        num_discarded_labels += 1

            if mc.DEBUG_MODE:
                print(
                    'Warning: Discarded {}/({}) labels that are assigned to the same '
                    'anchor'.format(num_discarded_labels, num_labels))

            feed_dict = {
                model.image_input:
                image_per_batch,
                model.keep_prob:
                mc.KEEP_PROB,
                model.input_mask:
                np.reshape(
                    sparse_to_dense(mask_indices, [mc.BATCH_SIZE, mc.ANCHORS],
                                    [1.0] * len(mask_indices)),
                    [mc.BATCH_SIZE, mc.ANCHORS, 1]),
                model.box_delta_input:
                sparse_to_dense(bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
                                box_delta_values),
                model.box_input:
                sparse_to_dense(bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
                                box_values),
                model.labels:
                sparse_to_dense(label_indices,
                                [mc.BATCH_SIZE, mc.ANCHORS, mc.CLASSES],
                                [1.0] * len(label_indices)),
            }

            if step % FLAGS.summary_step == 0:
                op_list = [
                    model.train_op, model.loss, summary_op, model.det_boxes,
                    model.det_probs, model.det_class, model.conf_loss,
                    model.bbox_loss, model.class_loss
                ]
                _, loss_value, summary_str, det_boxes, det_probs, det_class, conf_loss, \
                    bbox_loss, class_loss = sess.run(op_list, feed_dict=feed_dict)

                _viz_prediction_result(model, image_per_batch, bbox_per_batch,
                                       label_per_batch, det_boxes, det_class,
                                       det_probs)
                image_per_batch = bgr_to_rgb(image_per_batch)
                viz_summary = sess.run(
                    model.viz_op,
                    feed_dict={model.image_to_show: image_per_batch})

                num_discarded_labels_op = tf.scalar_summary(
                    'counter/num_discarded_labels', num_discarded_labels)
                num_labels_op = tf.scalar_summary('counter/num_labels',
                                                  num_labels)

                counter_summary_str = sess.run(
                    [num_discarded_labels_op, num_labels_op])

                summary_writer.add_summary(summary_str, step)
                summary_writer.add_summary(viz_summary, step)
                for sum_str in counter_summary_str:
                    summary_writer.add_summary(sum_str, step)

                print('conf_loss: {}, bbox_loss: {}, class_loss: {}'.format(
                    conf_loss, bbox_loss, class_loss))
            else:
                _, loss_value, conf_loss, bbox_loss, class_loss = sess.run(
                    [
                        model.train_op, model.loss, model.conf_loss,
                        model.bbox_loss, model.class_loss
                    ],
                    feed_dict=feed_dict)

            duration = time.time() - start_time

            assert not np.isnan(loss_value), \
                'Model diverged. Total loss: {}, conf_loss: {}, bbox_loss: {}, ' \
                'class_loss: {}'.format(loss_value, conf_loss, bbox_loss, class_loss)

            if step % 10 == 0:
                num_images_per_step = mc.BATCH_SIZE
                images_per_sec = num_images_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: step %d, loss = %.2f (%.1f images/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    images_per_sec, sec_per_batch))
                sys.stdout.flush()

            # Save the model checkpoint periodically.
            if step % FLAGS.checkpoint_step == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
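
One caveat in this example: tf.scalar_summary is called inside the training loop, so two new summary ops are added to the graph on every summary step. A sketch of the usual fix, using the same TF 0.x summary API as the example, builds the counter summaries once with placeholders (the placeholder names here are illustrative):

    # Before the training loop:
    num_discarded_ph = tf.placeholder(tf.float32, name='num_discarded_labels_ph')
    num_labels_ph = tf.placeholder(tf.float32, name='num_labels_ph')
    counter_summary_op = tf.merge_summary([
        tf.scalar_summary('counter/num_discarded_labels', num_discarded_ph),
        tf.scalar_summary('counter/num_labels', num_labels_ph),
    ])

    # Inside the loop, on summary steps:
    counter_summary_str = sess.run(counter_summary_op, feed_dict={
        num_discarded_ph: num_discarded_labels,
        num_labels_ph: num_labels,
    })
    summary_writer.add_summary(counter_summary_str, step)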
Example #4
        def _load_data(load_to_placeholder=True):
            # read batch input
            image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
                bbox_per_batch, image_per_batch_viz = imdb.read_batch()

            label_indices, bbox_indices, box_delta_values, mask_indices, \
                box_values = [], [], [], [], []
            aidx_set = set()
            num_discarded_labels = 0
            num_labels = 0
            for i in range(len(label_per_batch)):
                for j in range(len(label_per_batch[i])):
                    num_labels += 1
                    if (i, aidx_per_batch[i][j]) not in aidx_set:
                        aidx_set.add((i, aidx_per_batch[i][j]))
                        label_indices.append(
                            [i, aidx_per_batch[i][j], label_per_batch[i][j]])
                        mask_indices.append([i, aidx_per_batch[i][j]])
                        bbox_indices.extend([[i, aidx_per_batch[i][j], k]
                                             for k in range(4)])
                        box_delta_values.extend(box_delta_per_batch[i][j])
                        box_values.extend(bbox_per_batch[i][j])
                    else:
                        num_discarded_labels += 1

            if mc.DEBUG_MODE:
                print(
                    'Warning: Discarded {}/({}) labels that are assigned to the same '
                    'anchor'.format(num_discarded_labels, num_labels))

            if load_to_placeholder:
                image_input = model.ph_image_input
                input_mask = model.ph_input_mask
                box_delta_input = model.ph_box_delta_input
                box_input = model.ph_box_input
                labels = model.ph_labels
            else:
                image_input = model.image_input
                input_mask = model.input_mask
                box_delta_input = model.box_delta_input
                box_input = model.box_input
                labels = model.labels

            feed_dict = {
                image_input:
                image_per_batch,
                input_mask:
                np.reshape(
                    sparse_to_dense(mask_indices, [mc.BATCH_SIZE, mc.ANCHORS],
                                    [1.0] * len(mask_indices)),
                    [mc.BATCH_SIZE, mc.ANCHORS, 1]),
                box_delta_input:
                sparse_to_dense(bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
                                box_delta_values),
                box_input:
                sparse_to_dense(bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
                                box_values),
                labels:
                sparse_to_dense(label_indices,
                                [mc.BATCH_SIZE, mc.ANCHORS, mc.CLASSES],
                                [1.0] * len(label_indices)),
            }

            return feed_dict, image_per_batch, label_per_batch, bbox_per_batch, image_per_batch_viz
Example #5
    def _load_data(load_to_placeholder=True, eval_valid=False):
      # read batch input
      if eval_valid:
        # Only for validation set
        image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
          bbox_per_batch, edge_adhesions_per_batch = imdb_valid.read_batch(shuffle=False, wrap_around=False)
        keep_prob_value = 1.0
      else:
        image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
            bbox_per_batch, edge_adhesions_per_batch = imdb.read_batch()
        keep_prob_value = mc.DROP_OUT_PROB

      label_indices, bbox_indices, box_delta_values, mask_indices, box_values, \
          edge_adhesions, edge_indices = [], [], [], [], [], [], []
      aidx_set = set()
      num_discarded_labels = 0
      num_labels = 0
      for i in range(len(label_per_batch)): # batch_size
        for j in range(len(label_per_batch[i])): # number of annotations
          num_labels += 1
          if (i, aidx_per_batch[i][j]) not in aidx_set:
            aidx_set.add((i, aidx_per_batch[i][j]))
            label_indices.append(
                [i, aidx_per_batch[i][j], label_per_batch[i][j]])
            mask_indices.append([i, aidx_per_batch[i][j]])
            bbox_indices.extend(
                [[i, aidx_per_batch[i][j], k] for k in range(FLAGS.mask_parameterization)])
            box_delta_values.extend(box_delta_per_batch[i][j])
            box_values.extend(bbox_per_batch[i][j])
            edge_adhesions.extend(edge_adhesions_per_batch[i][j])
            edge_indices.extend(
                [[i, aidx_per_batch[i][j], k] for k in range(FLAGS.mask_parameterization)])
          else:
            num_discarded_labels += 1

      if mc.DEBUG_MODE:
        print('Warning: Discarded {}/({}) labels that are assigned to the same '
              'anchor'.format(num_discarded_labels, num_labels))

      if load_to_placeholder:
        image_input = model.ph_image_input
        input_mask = model.ph_input_mask
        box_delta_input = model.ph_box_delta_input
        box_input = model.ph_box_input
        labels = model.ph_labels
        edge_scenarios = model.ph_edge_adhesions
      else:
        image_input = model.image_input
        input_mask = model.input_mask
        box_delta_input = model.box_delta_input
        box_input = model.box_input
        labels = model.labels
        edge_scenarios = model.edge_adhesions

      feed_dict = {
          image_input: image_per_batch,
          input_mask: np.reshape(
              sparse_to_dense(
                  mask_indices, [mc.BATCH_SIZE, mc.ANCHORS],
                  [1.0]*len(mask_indices)),
              [mc.BATCH_SIZE, mc.ANCHORS, 1]),
          box_delta_input: sparse_to_dense(
              bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, FLAGS.mask_parameterization],
              box_delta_values),
          box_input: sparse_to_dense(
              bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, FLAGS.mask_parameterization],
              box_values),
          labels: sparse_to_dense(
              label_indices,
              [mc.BATCH_SIZE, mc.ANCHORS, mc.CLASSES],
              [1.0]*len(label_indices)),
          model.keep_prob: keep_prob_value,
          edge_scenarios: sparse_to_dense(
              edge_indices, [mc.BATCH_SIZE, mc.ANCHORS, FLAGS.mask_parameterization],
              edge_adhesions).astype(np.bool_, copy=False),
      }

      return feed_dict, image_per_batch, label_per_batch, bbox_per_batch, edge_indices
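
For a concrete sense of what the dense label tensor looks like, here is a toy check using the sparse_to_dense sketch from Example #1 (the shapes are made up for illustration):

    # One image, four anchors, three classes; anchor 2 is assigned class 1.
    label_indices = [[0, 2, 1]]
    labels = sparse_to_dense(label_indices, [1, 4, 3], [1.0])
    print(labels[0, 2])  # -> [0. 1. 0.]
    print(labels[0, 0])  # -> [0. 0. 0.]  (unassigned anchors stay all-zero)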
Example #6
def train():
    """Train SqueezeDet model"""
    assert FLAGS.net in ['vgg16', 'vgg16_v2', 'vgg16_v3', 'yolo_v2'], \
        'Selected neural net architecture not supported: {}'.format(FLAGS.net)
    assert FLAGS.dataset in ['KITTI', 'PASCAL_VOC', 'VID'], \
        'Dataset must be one of KITTI, PASCAL_VOC, or VID'
    if FLAGS.dataset == 'KITTI':
        if FLAGS.net == 'yolo_v2':
            raise NotImplementedError
        else:
            mc = kitti_vgg16_config()
        mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
        train_imdb = kitti(FLAGS.train_set, FLAGS.data_path, mc)
    elif FLAGS.dataset == 'PASCAL_VOC':
        if FLAGS.net == 'yolo_v2':
            mc = pascal_voc_yolo_config()
        else:
            mc = pascal_voc_vgg16_config()
        mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
        train_imdb = pascal_voc(FLAGS.train_set, FLAGS.year, FLAGS.data_path,
                                mc)
    elif FLAGS.dataset == 'VID':
        if FLAGS.net == 'yolo_v2':
            mc = vid_yolo_config()
        else:
            mc = vid_vgg16_config()
        mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
        train_imdb = vid(FLAGS.train_set, FLAGS.data_path, mc)

    with tf.Graph().as_default():

        if FLAGS.net == 'vgg16':
            model = VGG16ConvDet(mc, FLAGS.gpu)
        elif FLAGS.net == 'vgg16_v2':
            model = VGG16ConvDetV2(mc, FLAGS.gpu)
        elif FLAGS.net == 'vgg16_v3':
            model = VGG16ConvDetV3(mc, FLAGS.gpu)
        elif FLAGS.net == 'yolo_v2':
            model = YOLO_V2(mc, FLAGS.gpu)

        # save model size, flops, activations by layers
        with open(os.path.join(FLAGS.train_dir, 'model_metrics.txt'),
                  'w') as f:
            f.write('Number of parameter by layer:\n')
            count = 0
            for c in model.model_size_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))

            count = 0
            f.write('\nActivation size by layer:\n')
            for c in model.activation_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))

            count = 0
            f.write('\nNumber of flops by layer:\n')
            for c in model.flop_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))
        print('Model statistics saved to {}.'.format(
            os.path.join(FLAGS.train_dir, 'model_metrics.txt')))

        saver = tf.train.Saver(tf.all_variables())
        summary_op = tf.summary.merge_all()
        init = tf.initialize_all_variables()

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(init)
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        for step in xrange(FLAGS.max_steps):
            start_time = time.time()

            # read batch input
            image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
                bbox_per_batch = train_imdb.read_batch()

            label_indices, bbox_indices, box_delta_values, mask_indices, \
                box_values, num_discarded_labels, num_labels = _process_batch(
                    image_per_batch, label_per_batch, box_delta_per_batch,
                    aidx_per_batch, bbox_per_batch, mc.DEBUG_MODE)

            feed_dict = {
                model.image_input:
                image_per_batch,
                model.is_training:
                True,
                model.keep_prob:
                mc.KEEP_PROB,
                model.input_mask:
                np.reshape(
                    sparse_to_dense(mask_indices, [mc.BATCH_SIZE, mc.ANCHORS],
                                    [1.0] * len(mask_indices)),
                    [mc.BATCH_SIZE, mc.ANCHORS, 1]),
                model.box_delta_input:
                sparse_to_dense(bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
                                box_delta_values),
                model.box_input:
                sparse_to_dense(bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
                                box_values),
                model.labels:
                sparse_to_dense(label_indices,
                                [mc.BATCH_SIZE, mc.ANCHORS, mc.CLASSES],
                                [1.0] * len(label_indices)),
            }

            if step % FLAGS.summary_step == 0:
                op_list = [
                    model.train_op, model.loss, summary_op, model.det_boxes,
                    model.det_probs, model.det_class, model.conf_loss,
                    model.bbox_loss, model.class_loss
                ]
                _, loss_value, summary_str, det_boxes, det_probs, det_class, conf_loss, \
                    bbox_loss, class_loss = sess.run(op_list, feed_dict=feed_dict)

                _viz_prediction_result(model, image_per_batch, bbox_per_batch,
                                       label_per_batch, det_boxes, det_class,
                                       det_probs)
                image_per_batch = bgr_to_rgb(image_per_batch)
                viz_summary = sess.run(
                    model.viz_op,
                    feed_dict={model.image_to_show: image_per_batch})

                num_discarded_labels_op = tf.summary.scalar(
                    'counter/num_discarded_labels', num_discarded_labels)
                num_labels_op = tf.summary.scalar('counter/num_labels',
                                                  num_labels)

                counter_summary_str = sess.run(
                    [num_discarded_labels_op, num_labels_op])

                summary_writer.add_summary(summary_str, step)
                summary_writer.add_summary(viz_summary, step)
                for sum_str in counter_summary_str:
                    summary_writer.add_summary(sum_str, step)

                print('conf_loss: {}, bbox_loss: {}, class_loss: {}'.format(
                    conf_loss, bbox_loss, class_loss))
            else:
                _, loss_value, conf_loss, bbox_loss, class_loss = sess.run(
                    [
                        model.train_op, model.loss, model.conf_loss,
                        model.bbox_loss, model.class_loss
                    ],
                    feed_dict=feed_dict)

            duration = time.time() - start_time

            assert not np.isnan(loss_value), \
                'Model diverged. Total loss: {}, conf_loss: {}, bbox_loss: {}, ' \
                'class_loss: {}'.format(loss_value, conf_loss, bbox_loss, class_loss)

            if step % 10 == 0:
                num_images_per_step = mc.BATCH_SIZE
                images_per_sec = num_images_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: step %d, loss = %.2f (%.1f images/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    images_per_sec, sec_per_batch))
                sys.stdout.flush()

            # Save the model checkpoint periodically.
            if step % FLAGS.checkpoint_step == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
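
This example factors the anchor-assignment loop into a _process_batch helper that is called above but not shown. A sketch of what it presumably contains, reconstructed from the inline loop in Example #3 (the exact signature is inferred from the call site; image_per_batch is accepted only to match it):

    def _process_batch(image_per_batch, label_per_batch, box_delta_per_batch,
                       aidx_per_batch, bbox_per_batch, debug_mode=False):
        # Assign each ground-truth label to its anchor; when two labels map to
        # the same (image, anchor) pair, keep the first and count the rest.
        label_indices, bbox_indices, box_delta_values, mask_indices, \
            box_values = [], [], [], [], []
        aidx_set = set()
        num_discarded_labels = 0
        num_labels = 0
        for i in range(len(label_per_batch)):          # batch index
            for j in range(len(label_per_batch[i])):   # annotation index
                num_labels += 1
                if (i, aidx_per_batch[i][j]) not in aidx_set:
                    aidx_set.add((i, aidx_per_batch[i][j]))
                    label_indices.append(
                        [i, aidx_per_batch[i][j], label_per_batch[i][j]])
                    mask_indices.append([i, aidx_per_batch[i][j]])
                    bbox_indices.extend(
                        [[i, aidx_per_batch[i][j], k] for k in range(4)])
                    box_delta_values.extend(box_delta_per_batch[i][j])
                    box_values.extend(bbox_per_batch[i][j])
                else:
                    num_discarded_labels += 1
        if debug_mode:
            print('Warning: Discarded {}/({}) labels that are assigned to the '
                  'same anchor'.format(num_discarded_labels, num_labels))
        return (label_indices, bbox_indices, box_delta_values, mask_indices,
                box_values, num_discarded_labels, num_labels)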