Example 1
def acc():
    acc_all = []
    g = model_transfer.Graph(pb_file_path=pb_file_path)
    with tf.Session(graph=g.graph) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess,ckpt)
            print('restore from ckpt {}'.format(ckpt))
        else:
            print('cannot restore') 
        test_feeder = utils.DataIterator2(data_dir=data_dir, text_dir=text_dir)
        print("total data:", test_feeder.size)
        print("total image in folder", test_feeder.total_pic_read)
        # walk the test set in batches; the final batch may be smaller
        total_batches = (test_feeder.size + FLAGS.batch_size - 1) // FLAGS.batch_size
        for cur_batch in range(total_batches):
            print("cur_batch/total_batches", cur_batch, "/", total_batches)
            cur_batch_num = min(FLAGS.batch_size,
                                test_feeder.size - cur_batch * FLAGS.batch_size)
            indexs = [cur_batch * FLAGS.batch_size + i
                      for i in range(cur_batch_num)]
            test_inputs, test_seq_len, test_labels = test_feeder.input_index_generate_batch(
                indexs)
            cur_labels = [test_feeder.labels[i] for i in indexs]
            test_feed = {
                g.original_pic: test_inputs,
                g.labels: test_labels,
                g.seq_len: np.array([Flage_width] * test_inputs.shape[0])
            }
            dense_decoded = sess.run(g.dense_decoded, test_feed)
            acc = utils.accuracy_calculation(cur_labels,
                                             dense_decoded,
                                             ignore_value=-1,
                                             isPrint=False)
            acc_all.append(acc)
        print("$$$$$$$$$$$$$$$$$ ACC is :",acc_all,"$$$$$$$$$$$$$$$$$")
        print("avg_acc:",np.array(acc_all).mean()) 
Example 2
def test(val_dir=data_dir, val_text_dir=text_dir):
    g = model.Graph(is_training=True)  # dropout is neutralized below via keep_prob_cv=1
    print('loading validation data, please wait---------------------', end='')
    val_feeder = utils.DataIterator2(data_dir=val_dir, text_dir=val_text_dir)
    print('***************get image: ', val_feeder.size)

    num_val_samples = val_feeder.size
    num_val_per_epoch = int(num_val_samples / FLAGS.batch_size)

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Session(graph=g.graph, config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt {}'.format(ckpt))
        else:
            print('cannot restore')

        print(
            '=============================begin testing============================='
        )
        acc_avg = 0.0
        for cur_batch_cv in range(num_val_per_epoch):
            print("validation batch", cur_batch_cv, "/", num_val_per_epoch)
            index_cv = [cur_batch_cv * FLAGS.batch_size + i
                        for i in range(FLAGS.batch_size)]
            #print("index_cv", index_cv)
            val_inputs, val_seq_len, val_labels = val_feeder.input_index_generate_batch(
                index_cv)
            val_feed = {
                g.inputs: val_inputs,
                g.labels: val_labels,
                g.keep_prob_cv: 1
            }
            predict_word_index, lr = sess.run(
                [g.logits, g.learning_rate], val_feed)
            # show one validation sample: ground truth vs prediction
            print(val_labels[0], predict_word_index[0])
            acc = utils.compute_acc(val_labels, predict_word_index)
            acc_avg += acc
        acc_avg = acc_avg / num_val_per_epoch
        print("avg_acc:", acc_avg)
Example 3
def train(train_dir=None,
          val_dir=None,
          train_text_dir=None,
          val_text_dir=None):
    g = model.Graph(is_training=True)
    print('loading train data, please wait---------------------', end='')
    train_feeder = utils.DataIterator2(data_dir=train_dir,
                                       text_dir=train_text_dir)
    print('get image: ', train_feeder.size)
    print('loading validation data, please wait---------------------', end='')
    val_feeder = utils.DataIterator2(data_dir=val_dir, text_dir=val_text_dir)
    print('get image: ', val_feeder.size)

    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)
    num_val_samples = val_feeder.size
    num_val_per_epoch = int(num_val_samples / FLAGS.batch_size)

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=False)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = 0.6
    with tf.Session(graph=g.graph, config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=3)
        g.graph.finalize()
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)
        if FLAGS.restore:
            print("restore is true")
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print('restore from the checkpoint {0}'.format(ckpt))

        print(
            '=============================begin training============================='
        )
        val_inputs, val_seq_len, val_labels = val_feeder.input_index_generate_batch(
        )
        #print(len(val_inputs))
        val_feed = {
            g.inputs: val_inputs,
            g.labels: val_labels,
            g.seq_len: np.array([g.cnn_time] * val_inputs.shape[0])
        }
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()
            #the tracing part
            for cur_batch in range(num_batches_per_epoch):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time',
                          time.time() - batch_time)
                batch_time = time.time()
                indexs = [
                    shuffle_idx[i % num_train_samples]
                    for i in range(cur_batch *
                                   FLAGS.batch_size, (cur_batch + 1) *
                                   FLAGS.batch_size)
                ]
                batch_inputs, batch_seq_len, batch_labels = train_feeder.input_index_generate_batch(
                    indexs)
                #batch_inputs,batch_seq_len,batch_labels=utils.gen_batch(FLAGS.batch_size)
                feed = {
                    g.inputs: batch_inputs,
                    g.labels: batch_labels,
                    g.seq_len: np.array([g.cnn_time] * batch_inputs.shape[0])
                }

                # if summary is needed
                #batch_cost,step,train_summary,_ = sess.run([cost,global_step,merged_summay,optimizer],feed)
                summary_str, batch_cost, step, _ = sess.run(
                    [g.merged_summay, g.cost, g.global_step, g.optimizer],
                    feed)
                #calculate the cost
                train_cost += batch_cost * FLAGS.batch_size
                train_writer.add_summary(summary_str, step)

                # save the checkpoint
                if step % FLAGS.save_steps == 0:
                    print("save checkpoint", step)
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    logger.info('save the checkpoint of {0}'.format(step))
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)
                #train_err+=the_err*FLAGS.batch_size
                #do validation
                if step % FLAGS.validation_steps == 0:
                    dense_decoded, lastbatch_err, lr = sess.run(
                        [g.dense_decoded, g.lerr, g.learning_rate], val_feed)
                    # print the decode result
                    acc = utils.accuracy_calculation(val_feeder.labels,
                                                     dense_decoded,
                                                     ignore_value=-1,
                                                     isPrint=True)
                    avg_train_cost = train_cost / (
                        (cur_batch + 1) * FLAGS.batch_size)
                    #train_err/=num_train_samples
                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{}  step==={}, Epoch {}/{}, accuracy = {:.3f},avg_train_cost = {:.3f}, lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}\n"
                    print(
                        log.format(now.month, now.day, now.hour, now.minute,
                                   now.second, step, cur_epoch + 1,
                                   FLAGS.num_epochs, acc, avg_train_cost,
                                   lastbatch_err,
                                   time.time() - start_time, lr))
                    if Flag_Isserver:
                        with open('../log/acc/acc.txt', mode="a") as f:
                            f.write(
                                log.format(now.month, now.day, now.hour,
                                           now.minute, now.second, step,
                                           cur_epoch + 1, FLAGS.num_epochs, acc,
                                           avg_train_cost, lastbatch_err,
                                           time.time() - start_time, lr))
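
utils.accuracy_calculation consumes the dense CTC decode, which TensorFlow pads with -1; hence ignore_value=-1 in every call. A sketch of a sequence-level exact-match version (the real utility could equally score by edit distance):

def accuracy_calculation(original_seq, decoded_seq, ignore_value=-1, isPrint=False):
    # a decode counts as correct only if it equals the label exactly
    # after the padding value is dropped from both sides
    count = 0
    for orig, dec in zip(original_seq, decoded_seq):
        orig = [c for c in orig if c != ignore_value]
        dec = [c for c in dec if c != ignore_value]
        if isPrint:
            print(orig, dec)
        count += (orig == dec)
    return count * 100.0 / len(original_seq)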
Example 4
def train(train_dir=None,
          val_dir=None,
          train_text_dir=None,
          val_text_dir=None):
    acc_avg = 0.0
    acc_best = 0.0
    acc_best_step = 0
    g = model.Graph(is_training=True)
    print('loading train data, please wait---------------------', end='')
    train_feeder = utils.DataIterator2(data_dir=train_dir,
                                       text_dir=train_text_dir)
    print('***************get image: ', train_feeder.size)
    print('loading validation data, please wait---------------------', end='')
    val_feeder = utils.DataIterator2(data_dir=val_dir, text_dir=val_text_dir)
    print('***************get image: ', val_feeder.size)

    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)
    num_val_samples = val_feeder.size
    num_val_per_epoch = int(num_val_samples / FLAGS.batch_size)

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Session(graph=g.graph, config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        g.graph.finalize()
        if FLAGS.restore:
            print("restore is true")
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print('restore from the checkpoint {0}'.format(ckpt))

        print(
            '=============================begin training============================='
        )
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            for cur_batch in range(num_batches_per_epoch):
                now = datetime.datetime.now()
                improved = False  # set by the validation block below
                indexs = [
                    shuffle_idx[i % num_train_samples]
                    for i in range(cur_batch *
                                   FLAGS.batch_size, (cur_batch + 1) *
                                   FLAGS.batch_size)
                ]
                batch_inputs, batch_seq_len, batch_labels = train_feeder.input_index_generate_batch(
                    indexs)
                feed = {
                    g.inputs: batch_inputs,
                    g.labels: batch_labels,
                    g.keep_prob_cv: FLAGS.train_keep_prob_cv
                }

                batch_cost_avg, step, _, predict_result = sess.run(
                    [g.cost_batch_avg, g.global_step, g.optimizer, g.logits],
                    feed)
                print("cur_epoch====", cur_epoch, "cur_batch----",
                      cur_batch, "g_step****", step, "cost",
                      batch_cost_avg)

                if step % FLAGS.validation_steps == 0:
                    acc_avg = 0.0
                    for cur_batch_cv in range(num_val_per_epoch):
                        print("validation batch", cur_batch_cv, "/",
                              num_val_per_epoch)
                        index_cv = [cur_batch_cv * FLAGS.batch_size + i
                                    for i in range(FLAGS.batch_size)]
                        #print("index_cv", index_cv)
                        val_inputs, val_seq_len, val_labels = val_feeder.input_index_generate_batch(
                            index_cv)
                        val_feed = {
                            g.inputs: val_inputs,
                            g.labels: val_labels,
                            g.keep_prob_cv: 1
                        }
                        predict_word_index, lr = sess.run(
                            [g.logits, g.learning_rate], val_feed)
                        acc = utils.compute_acc(val_labels, predict_word_index)
                        acc_avg += acc
                    acc_avg = acc_avg / num_val_per_epoch
                    # remember whether this validation beat the best so far,
                    # so the checkpoint branch below can save on improvement
                    improved = acc_avg - acc_best > 0.00001
                    if improved:
                        acc_best = acc_avg
                        acc_best_step = step

                    print("acc", acc_avg)
                    if Flag_Isserver:
                        log = "{}/{} {}:{}:{}, Epoch {}/{}, step=={}-->cur_acc = {:.3f}, best_step=={}-->best_acc = {:.3f}, lr={:.8f},batch_cost_avg={:.3f}\n"
                        with open('../log/acc/acc.txt', mode="a") as f:
                            f.write(
                                log.format(now.month, now.day, now.hour,
                                           now.minute, now.second, cur_epoch + 1,
                                           FLAGS.num_epochs, step, acc_avg,
                                           acc_best_step, acc_best, lr,
                                           batch_cost_avg))

                if step % FLAGS.save_steps == 0 or improved:
                    print("save checkpoint", step)
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)
Example 5
def train(train_dir=None,
          val_dir=None,
          train_text_dir=None,
          val_text_dir=None,
          pb_file_path=None):
    g = model_transfer.Graph(is_training=True, pb_file_path=pb_file_path)

    print('loading train data, please wait---------------------', end='')
    train_feeder = utils.DataIterator2(data_dir=train_dir,
                                       text_dir=train_text_dir)
    print('***************get image: ', train_feeder.size)
    print('loading validation data, please wait---------------------', end='')
    val_feeder = utils.DataIterator2(data_dir=val_dir, text_dir=val_text_dir)
    print('***************get image: ', val_feeder.size)
    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)
    num_val_samples = val_feeder.size
    num_val_per_epoch = int(num_val_samples / FLAGS.batch_size)

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Session(graph=g.graph, config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=3)
        if FLAGS.restore:
            print("restore is true")
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print('restore from the checkpoint {0}'.format(ckpt))

        g.graph.finalize()
        print(
            '=============================begin training============================='
        )
        cur_training_step = 0
        val_inputs, val_seq_len, val_labels = val_feeder.input_index_generate_batch(
        )
        val_feed = {
            g.original_pic: val_inputs,
            g.labels: val_labels,
            g.seq_len: np.array([Flage_width] * val_inputs.shape[0])
        }

        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            for cur_batch in range(num_batches_per_epoch):
                cur_training_step += 1
                if cur_training_step % Flage_print_frequency == 0:
                    print("epoch", cur_epoch, "batch", cur_batch)

                indexs = [
                    shuffle_idx[i % num_train_samples]
                    for i in range(cur_batch *
                                   FLAGS.batch_size, (cur_batch + 1) *
                                   FLAGS.batch_size)
                ]
                batch_inputs, batch_seq_len, batch_labels = train_feeder.input_index_generate_batch(
                    indexs)
                transfer_train_batch_feed = {
                    g.original_pic: batch_inputs,
                    g.seq_len: np.array([Flage_width] * batch_inputs.shape[0]),
                    g.labels: batch_labels
                }
                summary_str, batch_cost, all_step, _ = sess.run(
                    [g.merged_summay, g.cost, g.global_step, g.optimizer],
                    transfer_train_batch_feed)
                train_cost += batch_cost * FLAGS.batch_size

                if all_step % FLAGS.save_steps == 0:
                    print("**********save checkpoint********** all_step:",
                          all_step)
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=all_step)

                if all_step % FLAGS.validation_steps == 0:
                    print("**********CrossValidation********** all_step:",
                          all_step)
                    dense_decoded, lastbatch_err, lr = sess.run(
                        [g.dense_decoded, g.lerr, g.learning_rate], val_feed)
                    acc = utils.accuracy_calculation(val_feeder.labels,
                                                     dense_decoded,
                                                     ignore_value=-1,
                                                     isPrint=True)
                    avg_train_cost = train_cost / (
                        (cur_batch + 1) * FLAGS.batch_size)
                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{}  all_step==={}, Epoch {}/{}, accuracy = {:.3f},avg_train_cost = {:.3f}, lastbatch_err = {:.3f},lr={:.8f}\n"
                    print(
                        log.format(now.month, now.day, now.hour, now.minute,
                                   now.second, all_step, cur_epoch + 1,
                                   FLAGS.num_epochs, acc, avg_train_cost,
                                   lastbatch_err, lr))
                    if Flag_Isserver:
                        with open('../log/acc/acc.txt', mode="a") as f:
                            f.write(
                                log.format(now.month, now.day, now.hour,
                                           now.minute, now.second, all_step,
                                           cur_epoch + 1, FLAGS.num_epochs, acc,
                                           avg_train_cost, lastbatch_err, lr))
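
model_transfer.Graph wraps a frozen .pb file (the acc() function in Example 1 uses it the same way). That class is not shown here; the standard TF1 pattern it presumably builds on is importing a serialized GraphDef and looking tensors up by name. A sketch, with tensor names invented purely for illustration:

import tensorflow as tf

def load_frozen_graph(pb_file_path):
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(pb_file_path, 'rb') as f:
            graph_def.ParseFromString(f.read())
        # pull the pretrained ops into this graph under a name prefix
        tf.import_graph_def(graph_def, name='frozen')
    # placeholder names below; the project's real tensor names are unknown
    original_pic = graph.get_tensor_by_name('frozen/input:0')
    feature_map = graph.get_tensor_by_name('frozen/feature_map:0')
    return graph, original_pic, feature_map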
Example 6
def train(train_dir=None,
          val_dir=None,
          train_text_dir=None,
          val_text_dir=None):
    g = model.Graph(is_training=True)
    print('loading train data, please wait---------------------', end='')
    train_feeder = utils.DataIterator2(data_dir=train_dir,
                                       text_dir=train_text_dir)
    print('***************get image: ', train_feeder.size)
    print('loading validation data, please wait---------------------', end='')
    val_feeder = utils.DataIterator2(data_dir=val_dir, text_dir=val_text_dir)
    print('***************get image: ', val_feeder.size)

    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)
    num_val_samples = val_feeder.size
    num_val_per_epoch = int(num_val_samples / FLAGS.batch_size)

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=False)
    with tf.Session(graph=g.graph, config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        g.graph.finalize()
        if FLAGS.restore:
            print("restore is true")
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print('restore from the checkpoint {0}'.format(ckpt))

        print(
            '=============================begin training============================='
        )
        val_inputs, val_seq_len, val_labels, val_labels_len = val_feeder.input_index_generate_batch(
        )
        val_feed = {
            g.inputs: val_inputs,
            g.y_: val_labels_len,
            g.keep_prob_fc: 1,
            g.keep_prob_cv1: 1
        }

        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            for cur_batch in range(num_batches_per_epoch):
                indexs = [
                    shuffle_idx[i % num_train_samples]
                    for i in range(cur_batch *
                                   FLAGS.batch_size, (cur_batch + 1) *
                                   FLAGS.batch_size)
                ]
                batch_inputs, batch_seq_len, batch_labels, batch_labels_len = train_feeder.input_index_generate_batch(
                    indexs)
                feed = {
                    g.inputs: batch_inputs,
                    g.y_: batch_labels_len,
                    g.keep_prob_fc: FLAGS.train_keep_prob_fc,
                    g.keep_prob_cv1: FLAGS.train_keep_prob_cv
                }
                _, step = sess.run([g.train_step, g.global_step], feed)
                if (step + 1) % FLAGS.validation_steps == 0:
                    train_accuracy = sess.run(g.accuracy, feed)
                    val_accuracy = sess.run(g.accuracy, val_feed)
                    print("===step", step, "train_acc:", train_accuracy,
                          "val_acc", val_accuracy)

                if (step + 1) % FLAGS.save_steps == 0:
                    print("save checkpoint", step)
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)
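
All six examples read a module-level FLAGS object that this listing never defines. A sketch of TF1-style flag definitions consistent with the names used above; every default value here is an assumption:

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('batch_size', 64, 'samples per batch')
tf.app.flags.DEFINE_integer('num_epochs', 100, 'passes over the training set')
tf.app.flags.DEFINE_string('checkpoint_dir', './checkpoint', 'where ocr-model-* files are saved')
tf.app.flags.DEFINE_boolean('restore', True, 'resume from the latest checkpoint if present')
tf.app.flags.DEFINE_integer('save_steps', 1000, 'save a checkpoint every N global steps')
tf.app.flags.DEFINE_integer('validation_steps', 500, 'run validation every N global steps')
tf.app.flags.DEFINE_string('log_dir', './log', 'TensorBoard summary directory (Example 3)')
tf.app.flags.DEFINE_float('train_keep_prob_cv', 0.8, 'conv dropout keep prob (Examples 4 and 6)')
tf.app.flags.DEFINE_float('train_keep_prob_fc', 0.5, 'fully-connected dropout keep prob (Example 6)')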