def report_accuracy(decoded_list, test_targets):
    """Print a per-sample exact-match report and return the sequence accuracy.

    Both arguments are sparse tensors decodable by decode_sparse_tensor.
    Returns None (after a warning) when the decoded batch size differs from
    the target batch size.
    """
    original_list = decode_sparse_tensor(test_targets)
    detected_list = decode_sparse_tensor(decoded_list)
    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list), "len(detected_list)",
              len(detected_list), " test and detect length desn't match")
        return
    print("T/F: original(length) <-------> detectcted(length)")
    match_count = 0
    for truth, guess in zip(original_list, detected_list):
        # Only same-length pairs are compared and reported; a length mismatch
        # is silently skipped (and counts as a miss).
        if len(truth) != len(guess):
            continue
        hit = all(t == g for t, g in zip(truth, guess))
        print(hit, truth, "(", len(truth), ") <-------> ", guess,
              "(", len(guess), ")")
        if hit:
            match_count += 1
    accuraccy = match_count * 1.0 / len(original_list)
    print("Test Accuracy:", accuraccy)
    return accuraccy
Ejemplo n.º 2
0
def report_accuracy(decoded_list, test_targets, test_names):
    """Report character-level accuracy based on edit distance.

    Prints one line per sample with its edit distance and per-sample accuracy,
    then returns the batch-level accuracy (total_len - total_ed) / total_len.
    Returns None (after a warning) on a batch-size mismatch.
    test_names is accepted for signature compatibility but unused here.
    """
    original_list = decode_sparse_tensor(test_targets)
    detected_list = decode_sparse_tensor(decoded_list)
    total_ed = 0
    total_len = 0

    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list), "len(detected_list)",
              len(detected_list), " test and detect length desn't match")
        return
    print("T/F: original(length) <-------> detectcted(length)")
    for number, detect_number in zip(original_list, detected_list):
        ed = editdistance.eval(number, detect_number)
        ln = len(number)
        # Guard: an empty ground-truth sequence would otherwise divide by zero.
        edit_accuracy = (ln - ed) / ln if ln else 0.0
        print("Edit: ", ed, "Edit accuracy: ", edit_accuracy, number, "(",
              len(number), ") <-------> ", detect_number, "(",
              len(detect_number), ")")
        total_ed += ed
        total_len += ln

    # Guard: empty batch (or all-empty targets) would divide by zero.
    if total_len == 0:
        print("Test Accuracy:", 0.0)
        return 0.0
    accuraccy = (total_len - total_ed) / total_len
    print("Test Accuracy:", accuraccy)
    return accuraccy
def report_accuracy(decoded_list, test_targets):
    """Print per-sample exact-match results and the overall sequence accuracy.

    Returns the accuracy (float), or None (after a warning) when the decoded
    batch size differs from the target batch size. The return value was added
    for consistency with the sibling implementations; previous callers that
    ignored the (None) result are unaffected.
    """
    original_list = decode_sparse_tensor(test_targets)
    detected_list = decode_sparse_tensor(decoded_list)
    true_numer = 0
    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list), "len(detected_list)", len(detected_list),
              " test and detect length desn't match")
        return
    print("T/F: original(length) <-------> detectcted(length)")
    for idx, number in enumerate(original_list):
        detect_number = detected_list[idx]
        # Exact sequence comparison: every character must match.
        hit = (number == detect_number)
        print(hit, number, "(", len(number), ") <-------> ", detect_number, "(", len(detect_number), ")")
        if hit:
            true_numer = true_numer + 1
    # Guard: an empty batch would otherwise divide by zero.
    accuraccy = true_numer * 1.0 / len(original_list) if original_list else 0.0
    print("Test Accuracy:", accuraccy)
    return accuraccy
Ejemplo n.º 4
0
def report_accuracy(decoded_list, train_targets):
    """Print per-sample exact-match results and the overall sequence accuracy
    for a training batch.

    Both arguments are sparse tensors decodable by decode_sparse_tensor.
    Returns None; the accuracy is only printed.
    """
    original_list = decode_sparse_tensor(train_targets)  # list of label sequences, e.g. ['6', '0', '6', '/', ...] — one list per sample
    detected_list = decode_sparse_tensor(decoded_list)  # list of predicted sequences, same batch length as original_list
    true_numer = 0
    # print(detected_list)
    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list), "len(detected_list)", len(detected_list),
              " test and detect length desn't match")
        return
    print("T/F: original(length) <-------> detectcted(length)")
    for idx, number in enumerate(original_list):
        detect_number = detected_list[idx]
        # Exact sequence comparison: every character must match.
        hit = (number == detect_number)
        print(hit, number, "(", len(number), ") <-------> ", detect_number, "(", len(detect_number), ")")
        if hit:
            true_numer = true_numer + 1
    print("Test Accuracy:", true_numer * 1.0 / len(original_list))
Ejemplo n.º 5
0
def report_accuracy(decoded_list, test_targets):
    """Compare each decoded sequence with its target and print the hit rate.

    Both arguments are sparse tensors decodable by decode_sparse_tensor.
    Prints the accuracy; returns None.
    """
    original_list = decode_sparse_tensor(test_targets)
    detected_list = decode_sparse_tensor(decoded_list)
    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list), "len(detected_list)",
              len(detected_list), " test and detect length desn't match")
        return
    print("T/F: original(length) <-------> detectcted(length)")
    hits = 0
    for expected, predicted in zip(original_list, detected_list):
        # A hit is an exact, full-sequence match.
        matched = expected == predicted
        print(matched, expected, "(", len(expected), ") <-------> ", predicted,
              "(", len(predicted), ")")
        if matched:
            hits += 1
    print("Test Accuracy:", hits * 1.0 / len(original_list))
Ejemplo n.º 6
0
def scan(file):
    """OCR a document image: split it into line images, clean each line with
    the generator network, decode with CTC, and return the recognized text.

    file: an uploaded file object read via ``file.stream`` (e.g. a Flask
    upload). Returns a list of decoded strings, one per recognized line.
    Relies on module-level ``session``, ``net_g``, ``res_decoded``,
    ``inputs``, ``seq_len``, ``ocr``, ``utils`` and ``curr_dir`` being
    initialized elsewhere in the file.
    """
    img = Image.open(file.stream)
    image = np.array(img)
    image = utils.img2gray(image)
    # Save the grayscale input for debugging.
    utils.save(image * 255, os.path.join(curr_dir, "test", "p0.png"))
    # image = utils.clearImgGray(image)
    # utils.save(image * 255, os.path.join(curr_dir,"test","p1.png"))
    split_images = utils.splitImg(image)

    ocr_texts = []

    for i, split_image in enumerate(split_images):
        # Binarize/clean an inverted copy, then trim empty borders of the
        # (inverted-intensity) line image and normalize to [0, 1].
        inv_image = utils.img2bwinv(split_image)
        inv_image = utils.clearImg(inv_image)
        image = 255. - split_image
        image = utils.dropZeroEdges(inv_image, image)
        image = utils.resize(image, ocr.image_height)
        image = image / 255.
        # Pack the line image into a fixed square canvas for the network.
        ocr_inputs = np.zeros([1, ocr.image_size, ocr.image_size])
        ocr_inputs[0, :] = utils.square_img(
            image, np.zeros([ocr.image_size, ocr.image_size]))

        # CTC sequence length after the pooling layers.
        ocr_seq_len = np.ones(1) * (ocr.image_size * ocr.image_size) // (
            ocr.POOL_SIZE * ocr.POOL_SIZE)

        start = time.time()
        p_net_g = session.run(net_g, {inputs: ocr_inputs})
        p_net_g = np.squeeze(p_net_g, axis=3)

        debug_net_g = np.copy(p_net_g)
        for j in range(1):
            # Unpack the generator output, drop weak activations (<= 0.2),
            # trim empty borders, then re-pack if it still fits the canvas.
            _t_img = utils.unsquare_img(p_net_g[j], ocr.image_height)
            _t_img_bin = np.copy(_t_img)
            _t_img_bin[_t_img_bin <= 0.2] = 0
            _t_img = utils.dropZeroEdges(_t_img_bin, _t_img, min_rate=0.1)
            _t_img = utils.resize(_t_img, ocr.image_height)
            if _t_img.shape[0] * _t_img.shape[
                    1] <= ocr.image_size * ocr.image_size:
                p_net_g[j] = utils.square_img(
                    _t_img, np.zeros([ocr.image_size, ocr.image_size]),
                    ocr.image_height)

        # Stack input / raw generator output / cleaned output for inspection.
        _img = np.vstack((ocr_inputs[0], debug_net_g[0], p_net_g[0]))
        utils.save(_img * 255, os.path.join(curr_dir, "test", "%s.png" % i))

        decoded_list = session.run(res_decoded[0], {
            inputs: p_net_g,
            seq_len: ocr_seq_len
        })
        seconds = round(time.time() - start, 2)
        print("filished ocr %s , paid %s seconds" % (i, seconds))
        detected_list = utils.decode_sparse_tensor(decoded_list)
        for detect_number in detected_list:
            ocr_texts.append(ocr.list_to_chars(detect_number))

    return ocr_texts
Ejemplo n.º 7
0
def detect(test_inputs, test_targets, test_seq_len):
    """Restore a saved OCR model, decode the test batch with CTC beam search,
    and return the fraction of sequences recognized exactly.

    Returns the accuracy (float), or None (after a warning) when the decoded
    batch size does not match the number of targets. Relies on module-level
    ``model``, ``tf`` and ``decode_sparse_tensor``.
    """
    logits, inputs, targets, seq_len, Wforward, Wbackward, b = model.get_train_model(
    )

    decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits,
                                                      seq_len,
                                                      merge_repeated=False)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restore variables from disk.
        saver.restore(sess, "models/ocr.model-0.64-17999")
        print("Model restored.")
        feed_dict = {inputs: test_inputs, seq_len: test_seq_len}
        dd = sess.run(decoded[0], feed_dict=feed_dict)
        original_list = decode_sparse_tensor(test_targets)
        detected_list = decode_sparse_tensor(dd)
        true_numer = 0
        # print(detected_list)
        if len(original_list) != (len(detected_list)):
            print("len(original_list)", len(original_list),
                  "len(detected_list)", len(detected_list),
                  " test and detect length desn't match")
            return
        print("T/F: original(length) <-------> detectcted(length)")
        for idx, number in enumerate(original_list):
            #if(idx==999):
            #		break
            detect_number = detected_list[idx]
            print(number, "(", len(number), ") <-------> ", detect_number, "(",
                  len(detect_number), ")")
            # A sample counts as a hit only when the lengths match and every
            # position is identical.
            if (len(number) == len(detect_number)):
                hit = True
                for idy, value in enumerate(number):
                    detect_value = detect_number[idy]
                    if (value != detect_value):
                        hit = False
                        break
                if hit:
                    true_numer = true_numer + 1
        accuraccy = true_numer * 1.0 / len(original_list)
        #print("Test Accuracy:", accuraccy)
        return accuraccy
Ejemplo n.º 8
0
def report(train_labels, decoded_list):
    """Print per-sample comparisons and the mean Levenshtein similarity ratio
    between decoded predictions and ground-truth labels.

    Both arguments are sparse tensors decodable by utils.decode_sparse_tensor.
    Prints the mean ratio (1.0 == identical); returns None.
    """
    # Hoisted out of the loop: importing per iteration was wasteful.
    import Levenshtein

    original_list = utils.decode_sparse_tensor(train_labels)
    detected_list = utils.decode_sparse_tensor(decoded_list)
    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list), "len(detected_list)",
              len(detected_list), " test and detect length desn't match")
    acc = 0.
    for idx in range(min(len(original_list), len(detected_list))):
        number = original_list[idx]
        detect_number = detected_list[idx]
        hit = (number == detect_number)
        print("----------", hit, "------------")
        print(list_to_chars(number), "(", len(number), ")")
        print(list_to_chars(detect_number), "(", len(detect_number), ")")
        # Accumulate the Levenshtein similarity ratio for this pair.
        acc += Levenshtein.ratio(list_to_chars(number),
                                 list_to_chars(detect_number))
    # Guard: an empty label batch would otherwise divide by zero.
    if original_list:
        print("Test Accuracy:", acc / len(original_list))
    else:
        print("Test Accuracy:", 0.0)
Ejemplo n.º 9
0
def report_accuracy(decoded_list, test_targets, test_names):
    """Report character-level accuracy using edit distance.

    Prints a sample comparison for every 10th entry, then returns the batch
    accuracy (total_len - total_ed) / total_len. Returns None (after a
    warning) on a batch-size mismatch. test_names is accepted for signature
    compatibility but unused (a dead commented-out block that wrote per-name
    output files was removed).
    """
    original_list = decode_sparse_tensor(test_targets)
    detected_list = decode_sparse_tensor(decoded_list)
    total_ed = 0
    total_len = 0

    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list), "len(detected_list)",
              len(detected_list), " test and detect length desn't match")
        return
    print("T/F: original(length) <-------> detectcted(length)")
    for idx, number in enumerate(original_list):
        detect_number = detected_list[idx]
        ed = editdistance.eval(number, detect_number)
        ln = len(number)
        # Guard: an empty ground-truth sequence would otherwise divide by zero.
        edit_accuracy = (ln - ed) / ln if ln else 0.0
        if (idx % 10 == 0):
            print("Edit: ", ed, "Edit accuracy: ",
                  edit_accuracy, "\n", ''.join(number).encode('utf-8'), "(",
                  len(number), ") <-------> ",
                  ''.join(detect_number).encode('utf-8'), "(",
                  len(detect_number), ")")
        total_ed += ed
        total_len += ln

    # Guard: empty batch (or all-empty targets) would divide by zero.
    if total_len == 0:
        print("Test Accuracy:", 0.0)
        return 0.0
    accuraccy = (total_len - total_ed) / total_len
    print("Test Accuracy:", accuraccy)
    return accuraccy
Ejemplo n.º 10
0
def detect(test_inputs, test_targets, test_seq_len):
    """Restore a saved model, decode the test batch with CTC beam search, and
    return the exact-match sequence accuracy.

    Returns the accuracy (float), or None (after a warning) on a batch-size
    mismatch. NOTE(review): the original body mixed tabs and spaces, which is
    a TabError under Python 3; the logic is preserved here with consistent
    4-space indentation, keeping decoding and scoring inside the session.
    """
    logits, inputs, targets, seq_len, W, b = model.get_train_model()

    decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len, merge_repeated=False)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restore variables from disk.
        saver.restore(sess, "models/ocr.model-0.95-94999")
        print("Model restored.")
        feed_dict = {inputs: test_inputs, seq_len: test_seq_len}
        dd = sess.run(decoded[0], feed_dict=feed_dict)
        original_list = decode_sparse_tensor(test_targets)
        detected_list = decode_sparse_tensor(dd)
        true_numer = 0
        if len(original_list) != len(detected_list):
            print("len(original_list)", len(original_list), "len(detected_list)", len(detected_list),
                  " test and detect length desn't match")
            return
        print("T/F: original(length) <-------> detectcted(length)")
        for idx, number in enumerate(original_list):
            detect_number = detected_list[idx]
            print(number, "(", len(number), ") <-------> ", detect_number, "(", len(detect_number), ")")
            # A hit requires matching lengths and identical characters.
            if len(number) == len(detect_number):
                hit = True
                for idy, value in enumerate(number):
                    if value != detect_number[idy]:
                        hit = False
                        break
                if hit:
                    true_numer = true_numer + 1
        accuraccy = true_numer * 1.0 / len(original_list)
        print("Test Accuracy:", accuraccy)
        return accuraccy
Ejemplo n.º 11
0
def report_accuracy(decoded_list, test_targets):
    """Print per-sample exact-match comparisons (with bilingual labels in the
    output) and the overall recognition accuracy; returns None."""
    # Report character-recognition accuracy: the share of samples whose
    # recognized character sequence exactly matches the labeled sequence.
    original_list = decode_sparse_tensor(test_targets)  # labeled character sequences
    detected_list = decode_sparse_tensor(decoded_list)  # recognized character sequences
    true_numer = 0
    # print(detected_list)
    if len(original_list) != len(detected_list):
        print(
            "len(original_list) 当前图片的标记字符长度", len(original_list),
            "len(detected_list) 从当前图片识别出的字符长度", len(detected_list),
            " test and detect length desn't match(从当前图片识别出的字符长度与标记的字符长度不一致)")
        return
    print(
        "T/F: original(length)  标记的字符长度 <-------> detectcted(length)  识别的字符长度")
    for idx, number in enumerate(original_list):
        detect_number = detected_list[idx]
        # Exact sequence comparison: every character must match.
        hit = (number == detect_number)
        print("是否识别当前图片:", hit, "  当前图片的标记字符为:", number, "(", len(number),
              "位 ) <-------> 从当前图片中识别出的字符为:", detect_number, "(",
              len(detect_number), "位 )")
        if hit:
            true_numer = true_numer + 1
    print("Test Accuracy(测试的准确率):",
          true_numer * 1.0 / len(original_list) * 1.0)
Ejemplo n.º 12
0
def detect(test_inputs, test_targets, test_seq_len):
    """Restore a trained model, run CTC beam-search decoding on the test
    batch, and return the decoded sequences.

    test_targets is accepted for signature compatibility but not fed to the
    graph here.
    """
    logits, inputs, targets, seq_len, W, b = model.get_train_model()

    decoded, log_prob = tf.contrib.ctc.ctc_beam_search_decoder(
        logits, seq_len, merge_repeated=False)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Bring back the trained weights before decoding.
        saver.restore(sess, "models/ocr.model-0.5-56499")
        print("Model restored.")
        result = sess.run(decoded[0],
                          feed_dict={inputs: test_inputs,
                                     seq_len: test_seq_len})
        return decode_sparse_tensor(result)
Ejemplo n.º 13
0
def detect():
    """Batch-evaluate a trained OCR model over a test list, writing each
    decoded transcription to an output file, and return the mean
    edit-distance accuracy across batches.

    Fixes applied to the original:
    - ``total_ed`` / ``total_len`` were accumulated OUTSIDE the per-sample
      loop (dedented), so only the last sample of each batch was counted;
      the accumulation now happens inside the loop.
    - Python-2 ``print`` statements converted to ``print()`` calls for
      consistency with the rest of the file.
    - The file-writing loop no longer shadows the outer batch index ``x``.
    - Division guards for empty batches.
    """
    model_restore = "models/ocr.model-0.959379192885-161999"  #give the path of your final model file
    test_filename = "test.txt"  #txt file containing the paths of all the test images
    output_folder = "test_outputs3/"  #where the outputs will be stored
    factor1 = 35  #think of it as number of test_batches
    factor2 = 41  #think of it as the batch size

    ac = 0
    logits, inputs, targets, seq_len, W, b = model.get_train_model()

    decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits,
                                                      seq_len,
                                                      merge_repeated=False)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restore variables from disk.
        saver.restore(sess, model_restore)
        print("Model restored.")
        for x in range(0, factor1):
            test_names, test_inputs, test_targets, test_seq_len = utils.get_data_set(
                test_filename, x * factor2, (x + 1) * factor2)
            print(test_inputs[0].shape)

            feed_dict = {inputs: test_inputs, seq_len: test_seq_len}
            dd, lp = sess.run([decoded[0], log_prob], feed_dict=feed_dict)
            original_list = decode_sparse_tensor(test_targets)
            detected_list = decode_sparse_tensor(dd)
            names_list = test_names.tolist()
            print("lp", lp)

            # Write each decoded transcription next to its source image name.
            for sample_idx, fname_save in enumerate(names_list):
                result = detected_list[sample_idx]
                out_file = codecs.open(
                    output_folder + os.path.basename(fname_save) + ".rnn.txt",
                    "w", "utf-8")
                out_file.write(''.join(result))
                out_file.close()

            if len(original_list) != len(detected_list):
                print("len(original_list)", len(original_list),
                      "len(detected_list)", len(detected_list),
                      " test and detect length desn't match")
                return
            print("T/F: original(length) <-------> detectcted(length)")
            total_ed = 0
            total_len = 0
            for idx, number in enumerate(original_list):
                detect_number = detected_list[idx]
                ed = editdistance.eval(number, detect_number)
                ln = len(number)
                # FIX: accumulate inside the per-sample loop.
                total_ed += ed
                total_len += ln

            # Guard: empty batch / all-empty targets would divide by zero.
            if total_len:
                accuraccy = (total_len - total_ed) * 1.0 / total_len
            else:
                accuraccy = 0.0
            print("Test Accuracy:", accuraccy)
            ac += accuraccy
    return ac * 1.0 / factor1
Ejemplo n.º 14
0
def train():
    """Main OCR training loop.

    Restores the latest checkpoint (rewinding the checkpoint file on a failed
    restore), then trains in batches: each batch is run once, plus up to 10
    extra iterations while its loss stays well above the running average.
    Every REPORT_STEPS steps it decodes a fresh batch, prints per-sample
    comparisons and a Levenshtein-ratio accuracy, and adjusts the learning
    rate from the average loss. Each outer pass saves a checkpoint and a
    summary (skipped when the loss is nan/inf).

    Relies on module-level names: neural_networks(), get_next_batch_for_res(),
    utils, list_to_chars(), MODEL_SAVE_NAME, BATCH_SIZE, BATCHES,
    REPORT_STEPS, tf, np, cv2, os, time, operator, deque.
    """
    inputs, labels, global_step, lr, summary, \
        res_loss, res_optim, seq_len, res_acc, res_decoded, net_res = neural_networks()

    curr_dir = os.path.dirname(__file__)
    model_dir = os.path.join(curr_dir, MODEL_SAVE_NAME)
    if not os.path.exists(model_dir): os.mkdir(model_dir)
    model_R_dir = os.path.join(model_dir, "RL32")
    if not os.path.exists(model_R_dir): os.mkdir(model_R_dir)

    log_dir = os.path.join(model_dir, "logs")
    if not os.path.exists(log_dir): os.mkdir(log_dir)

    with tf.Session() as session:
        print("tf init")
        # init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        # session.run(init_op)
        session.run(tf.global_variables_initializer())

        print("tf check restore")
        # r_saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='OCR'), sharded=True, max_to_keep=5)
        r_saver = tf.train.Saver(max_to_keep=5)

        # Up to 3 restore attempts: on failure, rewrite the checkpoint file to
        # point one BATCHES-worth of steps back and retry.
        for i in range(3):
            ckpt = tf.train.get_checkpoint_state(model_R_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print("Restore Model OCR...")
                stem = os.path.basename(ckpt.model_checkpoint_path)
                restore_iter = int(stem.split('-')[-1])
                try:
                    r_saver.restore(session, ckpt.model_checkpoint_path)    
                except:
                    new_restore_iter = restore_iter - BATCHES
                    with open(os.path.join(model_R_dir,"checkpoint"),'w') as f:
                        f.write('model_checkpoint_path: "OCR.ckpt-%s"\n'%new_restore_iter)
                        f.write('all_model_checkpoint_paths: "OCR.ckpt-%s"\n'%new_restore_iter)
                    continue
                session.run(tf.assign(global_step, restore_iter))
                # Step-based learning-rate schedule after a successful restore.
                if restore_iter<10000:        
                    session.run(tf.assign(lr, 1e-4))
                elif restore_iter<50000:            
                    session.run(tf.assign(lr, 1e-5))
                else:
                    session.run(tf.assign(lr, 1e-6))
                print("Restored to %s."%restore_iter)
                break
            else:
                break
            # NOTE(review): unreachable — every path above breaks or continues.
            print("restored fail, return")
            return

        print("tf create summary")
        train_writer = tf.summary.FileWriter(log_dir, session.graph)

        print("tf train")

        AllLosts={}
        accs = deque(maxlen=200)
        losts = deque(maxlen=200)
        while True:
            errR = 1
            batch_size = BATCH_SIZE
            for batch in range(BATCHES):
                start = time.time()    
                train_inputs, train_labels, train_seq_len, train_info = get_next_batch_for_res(batch_size)
                feed = {inputs: train_inputs, labels: train_labels, seq_len: train_seq_len} 

                feed_time = time.time() - start
                start = time.time()    

                # _res = session.run(net_res, feed)
                # print(train_inputs.shape)
                # print(_res.shape)
                # print(train_seq_len[0])

                errR, acc, _ , steps, res_lr = session.run([res_loss, res_acc, res_optim, global_step, lr], feed)
                font_length = int(train_info[0][-1])
                font_info = train_info[0][0]+"/"+train_info[0][1]+"/"+str(font_length)
                accs.append(acc)
                avg_acc = sum(accs)/len(accs)

                losts.append(errR)
                avg_losts = sum(losts)/len(losts)

                # errR = errR / font_length
                print("%s, %d time: %4.4fs / %4.4fs, acc: %.4f / %.4f, loss: %.4f / %.4f, lr:%.8f, info: %s " % \
                    (time.ctime(), steps, feed_time, time.time() - start, acc, avg_acc, errR, avg_losts, res_lr, font_info))

                # Run up to 10 extra training steps on this same batch while
                # its loss stays above twice the running average.
                need_reset_global_step = False
                for _ in range(10):
                    if errR <=  avg_losts*2: break 
                    start = time.time()                
                    errR, acc, _, res_lr = session.run([res_loss, res_acc, res_optim, lr], feed)
                    accs.append(acc)
                    avg_acc = sum(accs)/len(accs)                  
                    print("%s, %d time: 0.0000s / %4.4fs, acc: %.4f, avg_acc: %.4f, loss: %.4f, avg_loss: %.4f, lr:%.8f, info: %s " % \
                        (time.ctime(), steps, time.time() - start, acc, avg_acc, errR, avg_losts, res_lr, font_info))   
                    need_reset_global_step = True
                    
                if need_reset_global_step:                     
                    session.run(tf.assign(global_step, steps))

                # if np.isnan(errR) or np.isinf(errR) :
                #     print("Error: cost is nan or inf")
                #     return

                # Exponential moving average of accuracy per sample-info key.
                for info in train_info:
                    key = ",".join(info)
                    if key in AllLosts:
                        AllLosts[key]=AllLosts[key]*0.99+acc*0.01
                    else:
                        AllLosts[key]=acc

                # Dump images of unusually poor batches for inspection.
                if acc/avg_acc<=0.2:
                    for i in range(batch_size): 
                        filename = "%s_%s_%s_%s_%s_%s_%s.png"%(acc, steps, i, \
                            train_info[i][0], train_info[i][1], train_info[i][2], train_info[i][3])
                        cv2.imwrite(os.path.join(curr_dir,"test",filename), train_inputs[i] * 255)                    
                # Periodic report
                if steps >0 and steps % REPORT_STEPS == 0:
                    train_inputs, train_labels, train_seq_len, train_info = get_next_batch_for_res(batch_size)   
           
                    decoded_list = session.run(res_decoded[0], {inputs: train_inputs, seq_len: train_seq_len}) 

                    for i in range(batch_size): 
                        cv2.imwrite(os.path.join(curr_dir,"test","%s_%s.png"%(steps,i)), train_inputs[i] * 255) 

                    original_list = utils.decode_sparse_tensor(train_labels)
                    detected_list = utils.decode_sparse_tensor(decoded_list)
                    if len(original_list) != len(detected_list):
                        print("len(original_list)", len(original_list), "len(detected_list)", len(detected_list),
                            " test and detect length desn't match")
                    print("T/F: original(length) <-------> detectcted(length)")
                    acc = 0.
                    for idx in range(min(len(original_list),len(detected_list))):
                        number = original_list[idx]
                        detect_number = detected_list[idx]  
                        hit = (number == detect_number)
                        print("----------",hit,"------------")          
                        print(list_to_chars(number), "(", len(number), ")")
                        print(list_to_chars(detect_number), "(", len(detect_number), ")")
                        # Compute the Levenshtein similarity ratio
                        import Levenshtein
                        acc += Levenshtein.ratio(list_to_chars(number),list_to_chars(detect_number))
                    print("Test Accuracy:", acc / len(original_list))
                    sorted_fonts = sorted(AllLosts.items(), key=operator.itemgetter(1), reverse=False)
                    for f in sorted_fonts[:20]:
                        print(f)

                    # Loss-based learning-rate schedule.
                    if avg_losts>100:        
                        session.run(tf.assign(lr, 1e-4))
                    elif avg_losts>10:            
                        session.run(tf.assign(lr, 5e-5))
                    else:
                        session.run(tf.assign(lr, 1e-5))
                                            
            # If the current loss is nan/inf, skip saving this model.
            if np.isnan(errR) or np.isinf(errR):
                continue
            print("Save Model OCR ...")
            r_saver.save(session, os.path.join(model_R_dir, "OCR.ckpt"), global_step=steps)         
            logs = session.run(summary, feed)
            train_writer.add_summary(logs, steps)
Ejemplo n.º 15
0
def train():
    """GAN-style training loop for the OCR pipeline.

    Alternates updates: the recognizer (RES) trains every inner iteration;
    on the first inner iteration only, the generator is pre-trained on MSE,
    then the SRGAN generator and discriminator are each updated once. Every
    REPORT_STEPS window it writes input/prediction/target debug images and
    prints a Levenshtein-ratio recognition accuracy. Each outer pass saves
    all three sub-model checkpoints. Returns early if any loss goes nan/inf.

    Relies on module-level names: neural_networks(), get_next_batch(), utils,
    list_to_chars(), MODEL_SAVE_NAME, BATCHES, REPORT_STEPS, tf, np, cv2,
    os, time.
    """
    inputs, targets, labels, global_step, g_optim_init, d_loss, d_loss1, d_loss2, d_optim, \
        g_loss, g_mse_loss, g_res_loss, g_gan_loss, g_optim, net_g, \
        res_loss, res_optim, seq_len, res_acc, res_decoded = neural_networks()

    curr_dir = os.path.dirname(__file__)
    model_dir = os.path.join(curr_dir, MODEL_SAVE_NAME)
    if not os.path.exists(model_dir): os.mkdir(model_dir)
    model_R_dir = os.path.join(model_dir, "R")
    model_D_dir = os.path.join(model_dir, "D")
    model_G_dir = os.path.join(model_dir, "G")
    if not os.path.exists(model_R_dir): os.mkdir(model_R_dir)
    if not os.path.exists(model_D_dir): os.mkdir(model_D_dir)
    if not os.path.exists(model_G_dir): os.mkdir(model_G_dir)

    init = tf.global_variables_initializer()
    with tf.Session() as session:
        session.run(init)

        # Separate savers per sub-model, keyed by variable scope.
        r_saver = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='RES'),
                                 max_to_keep=5)
        d_saver = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='SRGAN_d'),
                                 max_to_keep=5)
        g_saver = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='SRGAN_g'),
                                 max_to_keep=5)

        ckpt = tf.train.get_checkpoint_state(model_G_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restore Model G...")
            g_saver.restore(session, ckpt.model_checkpoint_path)
        ckpt = tf.train.get_checkpoint_state(model_R_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restore Model R...")
            r_saver.restore(session, ckpt.model_checkpoint_path)
        ckpt = tf.train.get_checkpoint_state(model_D_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restore Model D...")
            d_saver.restore(session, ckpt.model_checkpoint_path)

        while True:
            for batch in range(BATCHES):
                for i in range(10):
                    train_inputs, train_targets, train_labels, train_seq_len = get_next_batch(
                        8)
                    feed = {
                        inputs: train_inputs,
                        targets: train_targets,
                        labels: train_labels,
                        seq_len: train_seq_len
                    }

                    # train res
                    start = time.time()
                    errM, acc, _, steps = session.run(
                        [res_loss, res_acc, res_optim, global_step], feed)
                    print("%d time: %4.4fs, res_loss: %.8f, res_acc: %.8f " %
                          (steps, time.time() - start, errM, acc))
                    if np.isnan(errM) or np.isinf(errM):
                        print("Error: cost is nan or inf")
                        return

                    # Only the first inner iteration also trains G and D.
                    if i > 0: continue

                    train_inputs, train_targets, train_labels, train_seq_len = get_next_batch(
                        4)
                    feed = {
                        inputs: train_inputs,
                        targets: train_targets,
                        labels: train_labels,
                        seq_len: train_seq_len
                    }

                    # train G
                    start = time.time()
                    errM, _, steps = session.run(
                        [g_mse_loss, g_optim_init, global_step], feed)
                    print("%d time: %4.4fs, g_mse_loss: %.8f " %
                          (steps, time.time() - start, errM))
                    if np.isnan(errM) or np.isinf(errM):
                        print("Error: cost is nan or inf")
                        return

                    # train GAN (SRGAN)

                    start = time.time()
                    ## update G
                    errG, errM, errV, errA, _, steps = session.run([
                        g_loss, g_mse_loss, g_res_loss, g_gan_loss, g_optim,
                        global_step
                    ], feed)
                    print(
                        "%d time: %4.4fs, g_loss: %.8f (mse: %.6f res: %.6f adv: %.6f)"
                        % (steps, time.time() - start, errG, errM, errV, errA))
                    if np.isnan(errG) or np.isinf(errG) or np.isnan(
                            errA) or np.isinf(errA):
                        print("Error: cost is nan or inf")
                        return

                    start = time.time()
                    ## update D
                    errD, errD1, errD2, _, steps = session.run(
                        [d_loss, d_loss1, d_loss2, d_optim, global_step], feed)
                    print(
                        "%d time: %4.4fs, d_loss: %.8f (d_loss1: %.6f  d_loss2: %.6f)"
                        % (steps, time.time() - start, errD, errD1, errD2))
                    if np.isnan(errD) or np.isinf(errD):
                        print("Error: cost is nan or inf")
                        return

                # Periodic report: dump debug images and recognition accuracy.
                if steps > 0 and steps % REPORT_STEPS < 13:
                    train_inputs, train_targets, train_labels, train_seq_len = get_next_batch(
                        4)
                    feed = {inputs: train_inputs, targets: train_targets}
                    b_predictions = session.run(net_g, feed)
                    for i in range(4):
                        # Stack input / prediction / target vertically per sample.
                        _predictions = np.reshape(b_predictions[i],
                                                  train_targets[i].shape)
                        _pred = np.transpose(_predictions)
                        _img = np.vstack((np.transpose(train_inputs[i]), _pred,
                                          np.transpose(train_targets[i])))
                        cv2.imwrite(
                            os.path.join(curr_dir, "test",
                                         "%s_%s.png" % (steps, i)), _img * 255)

                    feed = {
                        inputs: train_inputs,
                        targets: train_targets,
                        labels: train_labels,
                        seq_len: train_seq_len
                    }
                    decoded_list = session.run(res_decoded[0], feed)
                    original_list = utils.decode_sparse_tensor(train_labels)
                    detected_list = utils.decode_sparse_tensor(decoded_list)
                    if len(original_list) != len(detected_list):
                        print("len(original_list)", len(original_list),
                              "len(detected_list)", len(detected_list),
                              " test and detect length desn't match")
                    print("T/F: original(length) <-------> detectcted(length)")
                    acc = 0.
                    for idx in range(
                            min(len(original_list), len(detected_list))):
                        number = original_list[idx]
                        detect_number = detected_list[idx]
                        hit = (number == detect_number)
                        print("%6s" % hit, list_to_chars(number), "(",
                              len(number), ")")
                        print("%6s" % "", list_to_chars(detect_number), "(",
                              len(detect_number), ")")
                        # Compute the Levenshtein similarity ratio
                        import Levenshtein
                        acc += Levenshtein.ratio(list_to_chars(number),
                                                 list_to_chars(detect_number))
                    print("Test Accuracy:", acc / len(original_list))

            print("Save Model R ...")
            r_saver.save(session,
                         os.path.join(model_R_dir, "R.ckpt"),
                         global_step=steps)
            print("Save Model D ...")
            d_saver.save(session,
                         os.path.join(model_D_dir, "D.ckpt"),
                         global_step=steps)
            print("Save Model G ...")
            g_saver.save(session,
                         os.path.join(model_G_dir, "G.ckpt"),
                         global_step=steps)
0
def train():
    """Main training loop for the recognizer network R fed by generator G.

    Builds the graph via ``neural_networks()``, restores checkpoints for G
    (scope ``TRIM_G``) and R (scope ``RES``) if present, then loops forever:

    * Samples a batch, biased toward the font configurations with the worst
      recent loss (tracked in ``AllLosts``).
    * Runs G over the raw inputs to produce cleaned images, post-processes
      them (trim / resize / re-square), and trains R on the result with a
      CTC-style loss (``res_loss`` / ``res_optim``).
    * Every ``REPORT_STEPS`` steps, decodes a fresh batch and prints a
      per-sample comparison plus a Levenshtein-ratio accuracy.
    * After each epoch of ``BATCHES`` batches, saves R's checkpoint and
      re-restores G from its latest checkpoint (G itself is never trained
      here — only restored).

    Relies on module-level names not visible in this chunk:
    ``neural_networks``, ``get_next_batch_for_res``, ``utils``,
    ``list_to_chars``, ``BATCHES``, ``REPORT_STEPS``, ``image_height``,
    ``image_size``, ``MODEL_SAVE_NAME``.  Returns ``None``; exits early if
    the loss becomes NaN/Inf.
    """
    # Graph endpoints: placeholders (inputs/labels/seq_len), R's training ops
    # (loss/optimizer/accuracy/decoder), and G's output tensor net_g.
    inputs, labels, global_step, \
        res_loss, res_optim, seq_len, res_acc, res_decoded, \
        net_g = neural_networks()

    # Checkpoint directory layout: <script dir>/<MODEL_SAVE_NAME>/{TG, R16}.
    curr_dir = os.path.dirname(__file__)
    model_dir = os.path.join(curr_dir, MODEL_SAVE_NAME)
    if not os.path.exists(model_dir): os.mkdir(model_dir)
    model_G_dir = os.path.join(model_dir, "TG")
    model_R_dir = os.path.join(model_dir, "R16")

    if not os.path.exists(model_R_dir): os.mkdir(model_R_dir)
    if not os.path.exists(model_G_dir): os.mkdir(model_G_dir)

    init = tf.global_variables_initializer()
    with tf.Session() as session:
        session.run(init)

        # Separate savers per variable scope so R and G checkpoint
        # independently.  NOTE(review): r_saver is sharded while g_saver is
        # not — presumably intentional, but worth confirming.
        r_saver = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='RES'),
                                 sharded=True)
        g_saver = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='TRIM_G'),
                                 sharded=False)

        # Restore the generator G if a checkpoint exists.
        ckpt = tf.train.get_checkpoint_state(model_G_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restore Model G...")
            g_saver.restore(session, ckpt.model_checkpoint_path)

        # Restore the recognizer R if a checkpoint exists.
        ckpt = tf.train.get_checkpoint_state(model_R_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restore Model R...")
            r_saver.restore(session, ckpt.model_checkpoint_path)

        # Rolling loss per font configuration, keyed by the comma-joined
        # train_info tuple; used to over-sample hard fonts below.
        AllLosts = {}
        while True:
            # NOTE(review): errA/errD1/errD2 are assigned but never used in
            # this chunk — likely leftovers from a GAN training variant.
            errA = errD1 = errD2 = 1
            batch_size = 4
            for batch in range(BATCHES):
                # With 30% probability (once enough stats exist), resample
                # one of the 11 worst-loss font configs to focus training.
                if len(AllLosts) > 10 and random.random() > 0.7:
                    sorted_font = sorted(AllLosts.items(),
                                         key=operator.itemgetter(1),
                                         reverse=True)
                    font_info = sorted_font[random.randint(0, 10)]
                    font_info = font_info[0].split(",")
                    train_inputs, train_labels, train_seq_len, train_info = get_next_batch_for_res(batch_size, False, \
                        font_info[0], int(font_info[1]), int(font_info[2]), int(font_info[3]))
                else:
                    # train_inputs, train_labels, train_seq_len, train_info = get_next_batch_for_res(batch_size, False, _font_size=36)
                    train_inputs, train_labels, train_seq_len, train_info = get_next_batch_for_res(
                        batch_size)
                # feed = {inputs: train_inputs, labels: train_labels, seq_len: train_seq_len}
                start = time.time()

                # Run generator G to clean/normalize the raw input images.
                p_net_g = session.run(net_g, {inputs: train_inputs})

                # Drop the channel axis, then per-image: un-square, trim,
                # clamp negatives to 0, resize to image_height, and re-square
                # back into a fixed image_size canvas (only if it still fits).
                p_net_g = np.squeeze(p_net_g, axis=3)
                for i in range(batch_size):
                    _t_img = utils.unsquare_img(p_net_g[i], image_height)
                    _t_img = utils.cvTrimImage(_t_img)
                    _t_img[_t_img < 0] = 0
                    _t_img = utils.resize(_t_img, image_height)
                    if _t_img.shape[0] * _t_img.shape[
                            1] <= image_size * image_size:
                        p_net_g[i] = utils.square_img(
                            _t_img, np.zeros([image_size, image_size]),
                            image_height)

                # Train R on G's cleaned output rather than the raw inputs.
                feed = {
                    inputs: p_net_g,
                    labels: train_labels,
                    seq_len: train_seq_len
                }

                errR, acc, _, steps = session.run(
                    [res_loss, res_acc, res_optim, global_step], feed)
                # Human-readable font descriptor for the progress line.
                font_info = train_info[0][0] + "/" + train_info[0][
                    1] + " " + train_info[1][0] + "/" + train_info[1][1]
                print(
                    "%d time: %4.4fs, res_acc: %.4f, res_loss: %.4f, info: %s "
                    % (steps, time.time() - start, acc, errR, font_info))
                # Abort the whole run on a diverged loss.
                if np.isnan(errR) or np.isinf(errR):
                    print("Error: cost is nan or inf")
                    return

                # If accuracy is below 90%, save the input/generated image
                # pairs (stacked vertically) for manual inspection.
                if acc < 0.9:
                    for i in range(batch_size):
                        _img = np.vstack(
                            (train_inputs[i] * 255, p_net_g[i] * 255))
                        cv2.imwrite(
                            os.path.join(curr_dir, "test",
                                         "E%s_%s_%s.png" % (acc, steps, i)),
                            _img)

                # Update the exponential moving average of loss per font
                # config (alpha = 0.05 for new observations).
                for info in train_info:
                    key = ",".join(info)
                    if key in AllLosts:
                        AllLosts[key] = AllLosts[key] * 0.95 + errR * 0.05
                    else:
                        AllLosts[key] = errR

                # Periodic evaluation report.
                if steps > 0 and steps % REPORT_STEPS == 0:
                    train_inputs, train_labels, train_seq_len, train_info = get_next_batch_for_res(
                        batch_size)
                    p_net_g = session.run(net_g, {inputs: train_inputs})
                    p_net_g = np.squeeze(p_net_g, axis=3)

                    # Same post-processing as the training path, except the
                    # trim here uses a thresholded copy (<=0.3 zeroed) to
                    # decide the crop via dropZeroEdges.
                    for i in range(batch_size):
                        _t_img = utils.unsquare_img(p_net_g[i], image_height)
                        _t_img_bin = np.copy(_t_img)
                        _t_img_bin[_t_img_bin <= 0.3] = 0
                        _t_img = utils.dropZeroEdges(_t_img_bin,
                                                     _t_img,
                                                     min_rate=0.1)
                        _t_img = utils.resize(_t_img, image_height)
                        if _t_img.shape[0] * _t_img.shape[
                                1] <= image_size * image_size:
                            p_net_g[i] = utils.square_img(
                                _t_img, np.zeros([image_size, image_size]),
                                image_height)

                    # Decode R's predictions on the cleaned eval batch.
                    decoded_list = session.run(res_decoded[0], {
                        inputs: p_net_g,
                        seq_len: train_seq_len
                    })

                    # Dump the eval image pairs for inspection.
                    for i in range(batch_size):
                        _img = np.vstack((train_inputs[i], p_net_g[i]))
                        cv2.imwrite(
                            os.path.join(curr_dir, "test",
                                         "%s_%s.png" % (steps, i)), _img * 255)

                    original_list = utils.decode_sparse_tensor(train_labels)
                    detected_list = utils.decode_sparse_tensor(decoded_list)
                    if len(original_list) != len(detected_list):
                        print("len(original_list)", len(original_list),
                              "len(detected_list)", len(detected_list),
                              " test and detect length desn't match")
                    print("T/F: original(length) <-------> detectcted(length)")
                    acc = 0.
                    for idx in range(
                            min(len(original_list), len(detected_list))):
                        number = original_list[idx]
                        detect_number = detected_list[idx]
                        hit = (number == detect_number)
                        print("%6s" % hit, list_to_chars(number), "(",
                              len(number), ")")
                        print("%6s" % "", list_to_chars(detect_number), "(",
                              len(detect_number), ")")
                        # Accumulate the Levenshtein ratio (character-level
                        # similarity in [0, 1]) between label and prediction.
                        # NOTE(review): import inside the loop body — harmless
                        # (module caching) but should be hoisted to the top of
                        # the file.
                        import Levenshtein
                        acc += Levenshtein.ratio(list_to_chars(number),
                                                 list_to_chars(detect_number))
                    # NOTE(review): divides by len(original_list) — raises
                    # ZeroDivisionError if the eval batch decodes empty.
                    print("Test Accuracy:", acc / len(original_list))
                    # Show the 20 font configs with the worst rolling loss.
                    sorted_fonts = sorted(AllLosts.items(),
                                          key=operator.itemgetter(1),
                                          reverse=True)
                    for f in sorted_fonts[:20]:
                        print(f)
            # End of epoch: checkpoint R, then re-load G's latest weights
            # (G is presumably trained by a separate process writing to
            # model_G_dir — TODO confirm).
            print("Save Model R ...")
            r_saver.save(session,
                         os.path.join(model_R_dir, "R.ckpt"),
                         global_step=steps)
            try:
                ckpt = tf.train.get_checkpoint_state(model_G_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    print("Restore Model G...")
                    g_saver.restore(session, ckpt.model_checkpoint_path)
            except:
                # NOTE(review): bare except silently swallows every error,
                # including KeyboardInterrupt — should be narrowed (e.g. to
                # tf.errors.NotFoundError) and at least logged.
                pass