Example #1
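Freezes the trained OCR model for deployment: the function rebuilds the prediction graph, restores the latest checkpoint, converts all variables to constants, and serializes the result to a versioned .pb file.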
import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# framework, RunMode, NETWORK_MAP, NEU_CNN, NEU_RECURRENT, MODEL_PATH,
# COMPILE_MODEL_PATH and generate_config are defined elsewhere in the project.


def compile_graph(acc):
    # Build the inference graph in a fresh session.
    input_graph = tf.Graph()
    sess = tf.Session(graph=input_graph)

    with sess.graph.as_default():
        model = framework.GraphOCR(
            RunMode.Predict,
            NETWORK_MAP[NEU_CNN],
            NETWORK_MAP[NEU_RECURRENT]
        )
        model.build_graph()
        input_graph_def = sess.graph.as_graph_def()
        # Restore the most recent training checkpoint into the session.
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(MODEL_PATH))

    # Freeze the graph: fold trained variables into constants, keeping only
    # the nodes required to evaluate the decoded output.
    output_graph_def = convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names=['dense_decoded']
    )

    # Embed the accuracy in the filename as an integer, e.g. 0.9876 -> '_9876.pb'.
    last_compile_model_path = COMPILE_MODEL_PATH.replace('.pb', '_{}.pb'.format(int(acc * 10000)))
    with tf.gfile.FastGFile(last_compile_model_path, mode='wb') as gf:
        gf.write(output_graph_def.SerializeToString())

    generate_config(acc)
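The frozen .pb written above can later be reloaded for inference. Below is a minimal sketch, assuming TensorFlow 1.x; the file name is hypothetical (the real one is derived from COMPILE_MODEL_PATH):

import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile('compiled_model_9876.pb', 'rb') as f:  # hypothetical path
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    # 'dense_decoded' is the output node kept during freezing.
    tf.import_graph_def(graph_def, name='')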
Example #2
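The end-to-end training loop: it builds the model, loads samples either from TFRecords or from image directories, trains in batches, periodically saves checkpoints and runs validation, and finally calls compile_graph (Example #1) once the stopping criteria are met.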
import os
import random
import time

import numpy as np
import tensorflow as tf

# framework, utils, logger, exception, ConfigException and the upper-case
# configuration constants are defined elsewhere in the project.


def train_process(mode=RunMode.Trains):
    model = framework.GraphOCR(mode, NETWORK_MAP[NEU_CNN],
                               NETWORK_MAP[NEU_RECURRENT])
    model.build_graph()

    print('Loading Trains DataSet...')
    train_feeder = utils.DataIterator(mode=RunMode.Trains)
    # Samples come either from TFRecords files or from raw image directories.
    if TRAINS_USE_TFRECORDS:
        train_feeder.read_sample_from_tfrecords(TRAINS_PATH)
        print('Loading Test DataSet...')
        test_feeder = utils.DataIterator(mode=RunMode.Test)
        test_feeder.read_sample_from_tfrecords(TEST_PATH)
    else:
        # TRAINS_PATH may be a single directory or a list of directories.
        if isinstance(TRAINS_PATH, list):
            origin_list = []
            for trains_path in TRAINS_PATH:
                origin_list += [
                    os.path.join(trains_path, trains)
                    for trains in os.listdir(trains_path)
                ]
        else:
            origin_list = [
                os.path.join(TRAINS_PATH, trains)
                for trains in os.listdir(TRAINS_PATH)
            ]
        random.shuffle(origin_list)
        if not HAS_TEST_SET:
            # No dedicated test set: carve TEST_SET_NUM samples off the front.
            test_list = origin_list[:TEST_SET_NUM]
            trains_list = origin_list[TEST_SET_NUM:]
        else:
            if isinstance(TEST_PATH, list):
                test_list = []
                for test_path in TEST_PATH:
                    test_list += [
                        os.path.join(test_path, test)
                        for test in os.listdir(test_path)
                    ]
            else:
                test_list = [
                    os.path.join(TEST_PATH, test)
                    for test in os.listdir(TEST_PATH)
                ]
            random.shuffle(test_list)
            trains_list = origin_list
        train_feeder.read_sample_from_files(trains_list)
        print('Loading Test DataSet...')
        test_feeder = utils.DataIterator(mode=RunMode.Test)
        test_feeder.read_sample_from_files(test_list)

    print('Total {} Trains DataSets'.format(train_feeder.size))
    print('Total {} Test DataSets'.format(test_feeder.size))
    if test_feeder.size >= train_feeder.size:
        exception("The training set must be larger than the test set.")

    num_train_samples = train_feeder.size
    num_test_samples = test_feeder.size
    if num_test_samples < TEST_BATCH_SIZE:
        exception(
            "The number of test sets cannot be less than the test batch size.",
            ConfigException.INSUFFICIENT_SAMPLE)
    num_batches_per_epoch = num_train_samples // BATCH_SIZE

    config = tf.ConfigProto(
        # allow_soft_placement=True,
        log_device_placement=False,
        gpu_options=tf.GPUOptions(
            allocator_type='BFC',
            allow_growth=True,  # grow GPU memory on demand; may fragment memory
            per_process_gpu_memory_fraction=GPU_USAGE))
    accuracy = 0
    epoch_count = 1

    with tf.Session(config=config) as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=2)
        train_writer = tf.summary.FileWriter('logs', sess.graph)
        try:
            saver.restore(sess, tf.train.latest_checkpoint(MODEL_PATH))
        except ValueError:
            # No checkpoint found: start training from scratch.
            pass
        print('Start training...')

        while True:
            shuffle_trains_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            _avg_train_cost = 0
            for cur_batch in range(num_batches_per_epoch):
                batch_time = time.time()
                # Map the batch position onto the shuffled sample indices.
                index_list = [
                    shuffle_trains_idx[i % num_train_samples]
                    for i in range(cur_batch * BATCH_SIZE,
                                   (cur_batch + 1) * BATCH_SIZE)
                ]
                if TRAINS_USE_TFRECORDS:
                    batch_inputs, batch_seq_len, batch_labels = \
                        train_feeder.generate_batch_by_tfrecords(sess)
                else:
                    batch_inputs, batch_seq_len, batch_labels = \
                        train_feeder.generate_batch_by_files(index_list)

                # batch_seq_len is produced by the feeder but not fed to the model.
                feed = {
                    model.inputs: batch_inputs,
                    model.labels: batch_labels,
                }

                summary_str, batch_cost, step, _ = sess.run(
                    [model.merged_summary, model.cost, model.global_step,
                     model.train_op],
                    feed_dict=feed)
                train_cost += batch_cost * BATCH_SIZE
                avg_train_cost = train_cost / ((cur_batch + 1) * BATCH_SIZE)
                _avg_train_cost = avg_train_cost
                train_writer.add_summary(summary_str, step)

                if step % 100 == 0 and step != 0:
                    print('Step: {} Time: {:.3f}, Cost = {:.5f}'.format(
                        step,
                        time.time() - batch_time, avg_train_cost))

                if step % TRAINS_SAVE_STEPS == 0 and step != 0:
                    saver.save(sess, SAVE_MODEL, global_step=step)
                    logger.info('save checkpoint at step {0}'.format(step))

                if step % TRAINS_VALIDATION_STEPS == 0 and step != 0:
                    shuffle_test_idx = np.random.permutation(num_test_samples)
                    batch_time = time.time()
                    index_test = [
                        shuffle_test_idx[i % num_test_samples]
                        for i in range(cur_batch * TEST_BATCH_SIZE,
                                       (cur_batch + 1) * TEST_BATCH_SIZE)
                    ]
                    if TRAINS_USE_TFRECORDS:
                        test_inputs, batch_seq_len, test_labels = \
                            test_feeder.generate_batch_by_tfrecords(sess)
                    else:
                        test_inputs, batch_seq_len, test_labels = \
                            test_feeder.generate_batch_by_files(index_test)

                    val_feed = {
                        model.inputs: test_inputs,
                        model.labels: test_labels
                    }
                    # Decode predictions and score them against the ground truth.
                    dense_decoded, lr = sess.run(
                        [model.dense_decoded, model.lrn_rate],
                        feed_dict=val_feed)
                    accuracy = utils.accuracy_calculation(
                        test_feeder.labels(
                            None if TRAINS_USE_TFRECORDS else index_test),
                        dense_decoded,
                        ignore_value=[0, -1],
                    )
                    log = "Epoch: {}, Step: {}, Accuracy = {:.4f}, Cost = {:.5f}, " \
                          "Time = {:.3f}, LearningRate: {}"
                    print(
                        log.format(epoch_count, step, accuracy, avg_train_cost,
                                   time.time() - batch_time, lr))

                    # Stop once accuracy, epoch and cost targets are all met:
                    # break out of the batch loop, then re-check below to exit
                    # the epoch loop.
                    if accuracy >= TRAINS_END_ACC and epoch_count >= TRAINS_END_EPOCHS and avg_train_cost <= TRAINS_END_COST:
                        break
            if accuracy >= TRAINS_END_ACC and epoch_count >= TRAINS_END_EPOCHS and _avg_train_cost <= TRAINS_END_COST:
                compile_graph(accuracy)
                print('Total Time: {}'.format(time.time() - start_time))
                break
            epoch_count += 1
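As a usage sketch (assuming the module is executed as a script), training can be started directly:

if __name__ == '__main__':
    train_process()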