Example #1
import os
import time

import numpy as np
import tensorflow as tf
from tensorflow.contrib import slim

import data_generator

# FLAGS, build_graph and compute_loss come from the surrounding training
# script; a sketch of the flag definitions follows this example.
FLAGS = tf.app.flags.FLAGS


def main(argv=None):
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    if not tf.gfile.Exists(FLAGS.checkpoint_path):
        tf.gfile.MkDir(FLAGS.checkpoint_path)
    else:
        if not FLAGS.restore:
            tf.gfile.DeleteRecursively(FLAGS.checkpoint_path)
            tf.gfile.MkDir(FLAGS.checkpoint_path)

    input_images = tf.placeholder(tf.float32,
                                  shape=[None, None, None, 3],
                                  name='input_images')
    input_score_maps = tf.placeholder(tf.float32,
                                      shape=[None, None, None, 1],
                                      name='input_score_maps')
    input_geo_maps = tf.placeholder(tf.float32,
                                    shape=[None, None, None, 5],
                                    name='input_geo_maps')
    input_training_masks = tf.placeholder(tf.float32,
                                          shape=[None, None, None, 1],
                                          name='input_training_masks')
    input_transcription = tf.sparse_placeholder(tf.int32,
                                                name='input_transcription')

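    # Presumably the per-box affine transforms used for RoI cropping;
    # stop_gradient keeps them out of backpropagation, since they are
    # inputs rather than parameters to be learned.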
    input_transform_matrix = tf.placeholder(tf.float32,
                                            shape=[None, 6],
                                            name='input_transform_matrix')
    input_transform_matrix = tf.stop_gradient(input_transform_matrix)
    input_box_masks = []

    input_box_widths = tf.placeholder(tf.int32,
                                      shape=[None],
                                      name='input_box_widths')
    # All boxes in a batch are padded to the widest one, so every sequence
    # length fed to the recognizer equals the maximum box width.
    input_seq_len = tf.reduce_max(input_box_widths) * tf.ones_like(
        input_box_widths)

    for i in range(FLAGS.batch_size_per_gpu):
        input_box_masks.append(
            tf.placeholder(tf.int32,
                           shape=[None],
                           name='input_box_masks_' + str(i)))

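    # build_graph (defined in the surrounding script) is expected to return
    # the detector's score/geometry maps plus per-box recognition logits.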
    f_score, f_geometry, recognition_logits = build_graph(
        input_images, input_transform_matrix, input_box_masks,
        input_box_widths, input_seq_len)

    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    # learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, decay_steps=10000, decay_rate=0.94, staircase=True)
    learning_rate = FLAGS.learning_rate
    # add summary
    tf.summary.scalar('learning_rate', learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate)

    d_loss, r_loss, model_loss = compute_loss(f_score, f_geometry,
                                              recognition_logits,
                                              input_score_maps, input_geo_maps,
                                              input_training_masks,
                                              input_transcription,
                                              input_box_widths)
    # total_loss = detect_part.loss(input_score_maps, f_score, input_geo_maps, f_geometry, input_training_masks)
    tf.summary.scalar('total_loss', model_loss)
    total_loss = tf.add_n(
        [model_loss] + tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    # total_loss = model_loss
    batch_norm_updates_op = tf.group(
        *tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    if FLAGS.train_stage == 1:
        print("Train recognition branch only!")
        recog_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                       scope='recog')
        # Restrict the gradient computation to the recognition variables,
        # otherwise the detection branch would be updated as well.
        grads = opt.compute_gradients(total_loss, recog_vars)
    else:
        grads = opt.compute_gradients(total_loss)
    # Gradient clipping: cap each gradient's norm at 1.0 to stabilize training.
    for i, (g, v) in enumerate(grads):
        if g is not None:
            grads[i] = (tf.clip_by_norm(g, 1.0), v)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    summary_op = tf.summary.merge_all()
    # save moving average
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # Group the weight updates, EMA updates and batch-norm statistic updates
    # into a single training op.
    with tf.control_dependencies(
        [variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path,
                                           tf.get_default_graph())

    init = tf.global_variables_initializer()

    if FLAGS.pretrained_model_path is not None:
        if os.path.isdir(FLAGS.pretrained_model_path):
            print("Restore pretrained model from other datasets")
            ckpt = tf.train.latest_checkpoint(FLAGS.pretrained_model_path)
            variable_restore_op = slim.assign_from_checkpoint_fn(
                ckpt, slim.get_trainable_variables(), ignore_missing_vars=True)
        else:  # is *.ckpt
            print("Restore pretrained model from imagenet")
            variable_restore_op = slim.assign_from_checkpoint_fn(
                FLAGS.pretrained_model_path,
                slim.get_trainable_variables(),
                ignore_missing_vars=True)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        if FLAGS.restore:
            print('continue training from previous checkpoint')
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
            saver.restore(sess, ckpt)
        else:
            sess.run(init)
            if FLAGS.pretrained_model_path is not None:
                variable_restore_op(sess)

        dg = data_generator.get_batch(input_images_dir=FLAGS.training_data_dir,
                                      input_gt_dir=FLAGS.training_gt_data_dir,
                                      num_workers=FLAGS.num_readers,
                                      input_size=FLAGS.input_size,
                                      batch_size=FLAGS.batch_size_per_gpu)

        start = time.time()
        for step in range(FLAGS.max_steps):
            data = next(dg)
            inp_dict = {
                input_images: data[0],
                input_score_maps: data[2],
                input_geo_maps: data[3],
                input_training_masks: data[4],
                input_transform_matrix: data[5],
                input_box_widths: data[7],
                input_transcription: data[8]
            }

            for i in range(FLAGS.batch_size_per_gpu):
                inp_dict[input_box_masks[i]] = data[6][i]

            dl, rl, tl, _ = sess.run([d_loss, r_loss, total_loss, train_op],
                                     feed_dict=inp_dict)
            if np.isnan(tl):
                print('Loss diverged, stopping training')
                break

            if step % 10 == 0:
                avg_time_per_step = (time.time() - start) / 10
                avg_examples_per_second = (10 * FLAGS.batch_size_per_gpu) / (
                    time.time() - start)
                start = time.time()
                print(
                    'Step {:06d}, detect_loss {:.4f}, recognize_loss {:.4f}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'
                    .format(step, dl, rl, tl, avg_time_per_step,
                            avg_examples_per_second))
                """
                print "recognition results: "
                for pred in result:
                    print icdar.ground_truth_to_word(pred)
                """

            if step % FLAGS.save_checkpoint_steps == 0:
                saver.save(sess,
                           os.path.join(FLAGS.checkpoint_path, 'model.ckpt'),
                           global_step=global_step)

            if step % FLAGS.save_summary_steps == 0:
                """
                _, tl, summary_str = sess.run([train_op, total_loss, summary_op], feed_dict={input_images: data[0],
                                                                                             input_score_maps: data[2],
                                                                                             input_geo_maps: data[3],
                                                                                             input_training_masks: data[4]})
                """
                dl, rl, tl, _, summary_str = sess.run(
                    [d_loss, r_loss, total_loss, train_op, summary_op],
                    feed_dict=inp_dict)

                summary_writer.add_summary(summary_str, global_step=step)
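
Example #1 never defines its hyperparameters: everything is read from tf.app.flags. A minimal sketch of the flag definitions the surrounding training script presumably contains (flag names are taken from the FLAGS accesses above; the default values here are placeholders, not the original ones):

import tensorflow as tf

tf.app.flags.DEFINE_string('gpu_list', '0', 'GPUs visible to the process')
tf.app.flags.DEFINE_string('checkpoint_path', '/tmp/ckpt/', 'checkpoint directory')
tf.app.flags.DEFINE_boolean('restore', False, 'resume from the latest checkpoint')
tf.app.flags.DEFINE_integer('batch_size_per_gpu', 8, 'images per training step')
tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Adam learning rate')
tf.app.flags.DEFINE_float('moving_average_decay', 0.997, 'EMA decay for the weights')
tf.app.flags.DEFINE_integer('train_stage', 0, '1 trains the recognition branch only')
tf.app.flags.DEFINE_string('pretrained_model_path', None, 'checkpoint dir or *.ckpt file')
tf.app.flags.DEFINE_string('training_data_dir', None, 'directory with training images')
tf.app.flags.DEFINE_string('training_gt_data_dir', None, 'directory with ground-truth files')
tf.app.flags.DEFINE_integer('num_readers', 4, 'data-loading worker processes')
tf.app.flags.DEFINE_integer('input_size', 512, 'training crop size')
tf.app.flags.DEFINE_integer('max_steps', 100000, 'total number of training steps')
tf.app.flags.DEFINE_integer('save_checkpoint_steps', 1000, 'steps between checkpoints')
tf.app.flags.DEFINE_integer('save_summary_steps', 100, 'steps between summary writes')

FLAGS = tf.app.flags.FLAGS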
Example #2
# data_generator and lmdb_data_generator are project-local modules
# providing get_batch().
import data_generator
import lmdb_data_generator


def get_data(image_dir, gt_path, voc_type, max_len, num_samples, height, width,
             batch_size, workers, keep_ratio, with_aug):
    # num_samples is accepted for interface compatibility but is unused here.
    data_list = []
    if isinstance(image_dir, list) and len(image_dir) > 1:
        # assert len(image_dir) == len(gt_path), "datasets and gt are not corresponding"
        assert batch_size % len(image_dir) == 0, \
            "batch size must be divisible by the number of datasets"
        # Each dataset contributes an equal share of every batch.
        per_batch_size = batch_size // len(image_dir)
        if None in gt_path:
            # Using lmdb input
            for i in image_dir:
                data_list.append(
                    lmdb_data_generator.get_batch(workers,
                                                  lmdb_dir=i,
                                                  input_height=height,
                                                  input_width=width,
                                                  batch_size=per_batch_size,
                                                  max_len=max_len,
                                                  voc_type=voc_type,
                                                  keep_ratio=keep_ratio,
                                                  with_aug=with_aug))
        else:
            for i, g in zip(image_dir, gt_path):
                data_list.append(
                    data_generator.get_batch(workers,
                                             image_dir=i,
                                             gt_path=g,
                                             input_height=height,
                                             input_width=width,
                                             batch_size=per_batch_size,
                                             max_len=max_len,
                                             voc_type=voc_type,
                                             keep_ratio=keep_ratio,
                                             with_aug=with_aug))
    else:
        # Single dataset, possibly passed as a one-element list.
        if isinstance(image_dir, list):
            if None in gt_path:
                data = lmdb_data_generator.get_batch(workers,
                                                     lmdb_dir=image_dir[0],
                                                     input_height=height,
                                                     input_width=width,
                                                     batch_size=batch_size,
                                                     max_len=max_len,
                                                     voc_type=voc_type,
                                                     keep_ratio=keep_ratio,
                                                     with_aug=with_aug)
            else:
                data = data_generator.get_batch(workers,
                                                image_dir=image_dir[0],
                                                gt_path=gt_path[0],
                                                input_height=height,
                                                input_width=width,
                                                batch_size=batch_size,
                                                max_len=max_len,
                                                voc_type=voc_type,
                                                keep_ratio=keep_ratio,
                                                with_aug=with_aug)
        else:
            if gt_path is None:
                data = lmdb_data_generator.get_batch(workers,
                                                     lmdb_dir=image_dir,
                                                     input_height=height,
                                                     input_width=width,
                                                     batch_size=batch_size,
                                                     max_len=max_len,
                                                     voc_type=voc_type,
                                                     keep_ratio=keep_ratio,
                                                     with_aug=with_aug)
            else:
                data = data_generator.get_batch(workers,
                                                image_dir=image_dir,
                                                gt_path=gt_path,
                                                input_height=height,
                                                input_width=width,
                                                batch_size=batch_size,
                                                max_len=max_len,
                                                voc_type=voc_type,
                                                keep_ratio=keep_ratio,
                                                with_aug=with_aug)
        data_list.append(data)

    return data_list
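
A minimal usage sketch for get_data, assuming two LMDB datasets (the paths and the voc_type value are hypothetical; the returned generators come from the project's lmdb_data_generator module):

# batch_size (32) must be divisible by the number of datasets (2),
# so each LMDB source contributes 16 examples per batch.
generators = get_data(image_dir=['/data/synth_lmdb', '/data/real_lmdb'],
                      gt_path=[None, None],  # None selects the LMDB reader
                      voc_type='ALLCASES_SYMBOLS',
                      max_len=25,
                      num_samples=-1,        # unused by this snippet
                      height=32,
                      width=100,
                      batch_size=32,
                      workers=4,
                      keep_ratio=False,
                      with_aug=True)

batch = next(generators[0])  # pull one batch from the first dataset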