def report_accuracy(decoded_list, test_targets):
    """Print per-sample exact-match results and return sequence-level accuracy.

    Args:
        decoded_list: sparse tensor of decoded (predicted) label sequences.
        test_targets: sparse tensor of ground-truth label sequences.

    Returns:
        Fraction of samples whose decoded sequence matches the target exactly,
        or ``None`` when the two lists differ in length.
    """
    original_list = decode_sparse_tensor(test_targets)
    detected_list = decode_sparse_tensor(decoded_list)
    true_numer = 0
    # print(detected_list)
    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list),
              "len(detected_list)", len(detected_list),
              " test and detect length desn't match")
        return
    print("T/F: original(length) <-------> detectcted(length)")
    for idx, number in enumerate(original_list):
        detect_number = detected_list[idx]
        # Direct list equality replaces the original manual length check plus
        # element-by-element loop; it also avoids the original's stale /
        # undefined `hit` when the two sequences differ in length.
        hit = (number == detect_number)
        print(hit, number, "(", len(number), ") <-------> ",
              detect_number, "(", len(detect_number), ")")
        if hit:
            true_numer = true_numer + 1
    accuraccy = true_numer * 1.0 / len(original_list)
    print("Test Accuracy:", accuraccy)
    return accuraccy
def report_accuracy(decoded_list, test_targets, test_names):
    """Print per-sample edit distances and return corpus character accuracy.

    Args:
        decoded_list: sparse tensor of decoded (predicted) label sequences.
        test_targets: sparse tensor of ground-truth label sequences.
        test_names: sample-name array; kept for interface compatibility but
            not used by this function.

    Returns:
        Character-level accuracy ``(total_len - total_ed) / total_len``, or
        ``None`` when prediction and target counts differ.
    """
    original_list = decode_sparse_tensor(test_targets)
    detected_list = decode_sparse_tensor(decoded_list)
    # NOTE: the original also built `names_list = test_names.tolist()` but
    # never used it; that dead local has been removed.
    total_ed = 0
    total_len = 0
    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list),
              "len(detected_list)", len(detected_list),
              " test and detect length desn't match")
        return
    print("T/F: original(length) <-------> detectcted(length)")
    for idx, number in enumerate(original_list):
        detect_number = detected_list[idx]
        ed = editdistance.eval(number, detect_number)
        ln = len(number)
        # `* 1.0` forces float division; under Python 2 the original integer
        # division truncated the ratio to 0 for any imperfect sample.
        # No behavior change under Python 3.
        edit_accuracy = (ln - ed) * 1.0 / ln
        print("Edit: ", ed, "Edit accuracy: ", edit_accuracy, number, "(",
              len(number), ") <-------> ", detect_number, "(",
              len(detect_number), ")")
        total_ed += ed
        total_len += ln
    accuraccy = (total_len - total_ed) * 1.0 / total_len
    print("Test Accuracy:", accuraccy)
    return accuraccy
def report_accuracy(decoded_list, test_targets):
    """Print each target/prediction pair and the overall exact-match accuracy.

    Args:
        decoded_list: sparse tensor of decoded (predicted) label sequences.
        test_targets: sparse tensor of ground-truth label sequences.
    """
    original_list = decode_sparse_tensor(test_targets)
    detected_list = decode_sparse_tensor(decoded_list)
    # print(detected_list)
    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list),
              "len(detected_list)", len(detected_list),
              " test and detect length desn't match")
        return
    print("T/F: original(length) <-------> detectcted(length)")
    hits = 0
    # Lengths are equal here, so zip walks the pairs in lockstep.
    for expected, predicted in zip(original_list, detected_list):
        matched = expected == predicted
        print(matched, expected, "(", len(expected), ") <-------> ",
              predicted, "(", len(predicted), ")")
        if matched:
            hits += 1
    print("Test Accuracy:", hits * 1.0 / len(original_list))
def report_accuracy(decoded_list, train_targets): original_list = decode_sparse_tensor(train_targets)#list,length:100,['6', '0', '6', '/','2', '血', '1', '9', ' ', '1', '2', ':', '2', '2', '~'] detected_list = decode_sparse_tensor(decoded_list)#list,lenth:100,预测序列['0', '2', '0', '2', '0', '2', '0', '2', '2', '2', '0', '2', '0', '2', '0', '2', '0', '2', '2', '2', '2', '2', '0', '2', '2', '0', '2'] true_numer = 0 # print(detected_list) if len(original_list) != len(detected_list): print("len(original_list)", len(original_list), "len(detected_list)", len(detected_list), " test and detect length desn't match") return print("T/F: original(length) <-------> detectcted(length)") for idx, number in enumerate(original_list): detect_number = detected_list[idx] hit = (number == detect_number) print(hit, number, "(", len(number), ") <-------> ", detect_number, "(", len(detect_number), ")") if hit: true_numer = true_numer + 1 print("Test Accuracy:", true_numer * 1.0 / len(original_list))
def scan(file):
    """Run the full OCR pipeline on one uploaded image and return its text.

    Args:
        file: uploaded-file object exposing a ``.stream`` attribute
            (presumably a Flask/Werkzeug FileStorage — TODO confirm).

    Returns:
        List of recognized text strings, one per detected line segment.
    """
    img = Image.open(file.stream)
    image = np.array(img)
    image = utils.img2gray(image)
    # Dump the grayscale input for debugging.
    utils.save(image * 255, os.path.join(curr_dir, "test", "p0.png"))
    # image = utils.clearImgGray(image)
    # utils.save(image * 255, os.path.join(curr_dir,"test","p1.png"))
    # Split the page into line images (semantics of splitImg assumed — verify).
    split_images = utils.splitImg(image)
    ocr_texts = []
    for i, split_image in enumerate(split_images):
        # Binarize/invert, clean noise, then normalize the line image.
        inv_image = utils.img2bwinv(split_image)
        inv_image = utils.clearImg(inv_image)
        image = 255. - split_image
        image = utils.dropZeroEdges(inv_image, image)
        image = utils.resize(image, ocr.image_height)
        image = image / 255.
        # Batch of one, padded into a square canvas of ocr.image_size.
        ocr_inputs = np.zeros([1, ocr.image_size, ocr.image_size])
        ocr_inputs[0, :] = utils.square_img(
            image, np.zeros([ocr.image_size, ocr.image_size]))
        # CTC sequence length derived from the pooled feature-map size.
        ocr_seq_len = np.ones(1) * (ocr.image_size * ocr.image_size) // (
            ocr.POOL_SIZE * ocr.POOL_SIZE)
        start = time.time()
        # First pass: generator network cleans/enhances the input image.
        p_net_g = session.run(net_g, {inputs: ocr_inputs})
        p_net_g = np.squeeze(p_net_g, axis=3)
        debug_net_g = np.copy(p_net_g)
        # Re-normalize the generated image (threshold, trim, resize, re-pad).
        for j in range(1):
            _t_img = utils.unsquare_img(p_net_g[j], ocr.image_height)
            _t_img_bin = np.copy(_t_img)
            _t_img_bin[_t_img_bin <= 0.2] = 0
            _t_img = utils.dropZeroEdges(_t_img_bin, _t_img, min_rate=0.1)
            _t_img = utils.resize(_t_img, ocr.image_height)
            # Only re-pad when the trimmed image still fits the square canvas.
            if _t_img.shape[0] * _t_img.shape[
                    1] <= ocr.image_size * ocr.image_size:
                p_net_g[j] = utils.square_img(
                    _t_img, np.zeros([ocr.image_size, ocr.image_size]),
                    ocr.image_height)
        # Save input / raw generator output / cleaned output stacked vertically.
        _img = np.vstack((ocr_inputs[0], debug_net_g[0], p_net_g[0]))
        utils.save(_img * 255, os.path.join(curr_dir, "test", "%s.png" % i))
        # Second pass: CTC decode the recognition head on the cleaned image.
        decoded_list = session.run(res_decoded[0], {
            inputs: p_net_g,
            seq_len: ocr_seq_len
        })
        seconds = round(time.time() - start, 2)
        print("filished ocr %s , paid %s seconds" % (i, seconds))
        detected_list = utils.decode_sparse_tensor(decoded_list)
        for detect_number in detected_list:
            ocr_texts.append(ocr.list_to_chars(detect_number))
    return ocr_texts
def detect(test_inputs, test_targets, test_seq_len):
    """Restore a trained model, decode the test batch, and report accuracy.

    Args:
        test_inputs: batch of input feature sequences fed to the network.
        test_targets: sparse tensor of ground-truth label sequences.
        test_seq_len: per-sample input sequence lengths for CTC.

    Returns:
        Exact-match sequence accuracy, or ``None`` on a length mismatch.
    """
    logits, inputs, targets, seq_len, Wforward, Wbackward, b = model.get_train_model(
    )
    decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits,
                                                      seq_len,
                                                      merge_repeated=False)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restore variables from disk (hard-coded checkpoint path).
        saver.restore(sess, "models/ocr.model-0.64-17999")
        print("Model restored.")
        feed_dict = {inputs: test_inputs, seq_len: test_seq_len}
        dd = sess.run(decoded[0], feed_dict=feed_dict)
        original_list = decode_sparse_tensor(test_targets)
        detected_list = decode_sparse_tensor(dd)
        true_numer = 0
        # print(detected_list)
        if len(original_list) != (len(detected_list)):
            print("len(original_list)", len(original_list),
                  "len(detected_list)", len(detected_list),
                  " test and detect length desn't match")
            return
        print("T/F: original(length) <-------> detectcted(length)")
        for idx, number in enumerate(original_list):
            #if(idx==999):
            #    break
            detect_number = detected_list[idx]
            print(number, "(", len(number), ") <-------> ", detect_number,
                  "(", len(detect_number), ")")
            # Count a hit only when lengths match and every element agrees;
            # samples with differing lengths are skipped entirely.
            if (len(number) == len(detect_number)):
                hit = True
                for idy, value in enumerate(number):
                    detect_value = detect_number[idy]
                    if (value != detect_value):
                        hit = False
                        break
                if hit:
                    true_numer = true_numer + 1
        accuraccy = true_numer * 1.0 / len(original_list)
        #print("Test Accuracy:", accuraccy)
        return accuraccy
def report(train_labels, decoded_list):
    """Print ground-truth vs. decoded sequences and the mean Levenshtein ratio.

    Args:
        train_labels: sparse tensor of ground-truth label sequences.
        decoded_list: sparse tensor of decoded (predicted) label sequences.
    """
    # Hoisted out of the per-sample loop: the module cache makes repeated
    # imports a no-op, but the lookup still ran on every iteration.
    import Levenshtein

    original_list = utils.decode_sparse_tensor(train_labels)
    detected_list = utils.decode_sparse_tensor(decoded_list)
    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list),
              "len(detected_list)", len(detected_list),
              " test and detect length desn't match")
    acc = 0.
    # Iterate over the overlap so a length mismatch still produces a report.
    for idx in range(min(len(original_list), len(detected_list))):
        number = original_list[idx]
        detect_number = detected_list[idx]
        hit = (number == detect_number)
        print("----------", hit, "------------")
        print(list_to_chars(number), "(", len(number), ")")
        print(list_to_chars(detect_number), "(", len(detect_number), ")")
        # Accumulate the Levenshtein similarity ratio of the two strings.
        acc += Levenshtein.ratio(list_to_chars(number),
                                 list_to_chars(detect_number))
    print("Test Accuracy:", acc / len(original_list))
def report_accuracy(decoded_list, test_targets, test_names):
    """Report edit-distance accuracy, printing every 10th sample's detail.

    Args:
        decoded_list: sparse tensor of decoded (predicted) label sequences.
        test_targets: sparse tensor of ground-truth label sequences.
        test_names: sample-name array; kept for interface compatibility but
            not used by this function.

    Returns:
        Character-level accuracy ``(total_len - total_ed) / total_len``, or
        ``None`` when prediction and target counts differ.
    """
    original_list = decode_sparse_tensor(test_targets)
    detected_list = decode_sparse_tensor(decoded_list)
    # Removed: an unused `names_list = test_names.tolist()` local and a dead
    # commented-out (triple-quoted) per-sample file-writing block.
    true_numer = 0
    total_ed = 0
    total_len = 0
    if len(original_list) != len(detected_list):
        print("len(original_list)", len(original_list),
              "len(detected_list)", len(detected_list),
              " test and detect length desn't match")
        return
    print("T/F: original(length) <-------> detectcted(length)")
    for idx, number in enumerate(original_list):
        detect_number = detected_list[idx]
        ed = editdistance.eval(number, detect_number)
        ln = len(number)
        # `* 1.0` forces float division; under Python 2 the original integer
        # division truncated the ratio. No behavior change under Python 3.
        edit_accuracy = (ln - ed) * 1.0 / ln
        # Only print every 10th sample to keep the log readable.
        if (idx % 10 == 0):
            print("Edit: ", ed, "Edit accuracy: ", edit_accuracy, "\n",
                  ''.join(number).encode('utf-8'), "(", len(number),
                  ") <-------> ", ''.join(detect_number).encode('utf-8'), "(",
                  len(detect_number), ")")
        total_ed += ed
        total_len += ln
    accuraccy = (total_len - total_ed) * 1.0 / total_len
    print("Test Accuracy:", accuraccy)
    return accuraccy
def detect(test_inputs, test_targets, test_seq_len):
    """Restore a trained model, decode the test batch, and report accuracy.

    Args:
        test_inputs: batch of input feature sequences fed to the network.
        test_targets: sparse tensor of ground-truth label sequences.
        test_seq_len: per-sample input sequence lengths for CTC.

    Returns:
        Exact-match sequence accuracy, or ``None`` on a length mismatch.
    """
    logits, inputs, targets, seq_len, W, b = model.get_train_model()
    decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits,
                                                      seq_len,
                                                      merge_repeated=False)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restore variables from disk (hard-coded checkpoint path).
        saver.restore(sess, "models/ocr.model-0.95-94999")
        print("Model restored.")
        #feed_dict = {inputs: test_inputs, targets: test_targets, seq_len: test_seq_len}
        feed_dict = {inputs: test_inputs, seq_len: test_seq_len}
        dd = sess.run(decoded[0], feed_dict=feed_dict)
        #return decode_sparse_tensor(dd)
        original_list = decode_sparse_tensor(test_targets)
        detected_list = decode_sparse_tensor(dd)
        true_numer = 0
        # print(detected_list)
        if len(original_list) != len(detected_list):
            print("len(original_list)", len(original_list),
                  "len(detected_list)", len(detected_list),
                  " test and detect length desn't match")
            return
        print("T/F: original(length) <-------> detectcted(length)")
        for idx, number in enumerate(original_list):
            detect_number = detected_list[idx]
            print(number, "(", len(number), ") <-------> ", detect_number,
                  "(", len(detect_number), ")")
            # Count a hit only when lengths match and every element agrees;
            # samples with differing lengths are skipped entirely.
            if(len(number) == len(detect_number)):
                hit = True
                for idy, value in enumerate(number):
                    detect_value = detect_number[idy]
                    if(value != detect_value):
                        hit = False
                        break
                if hit:
                    true_numer = true_numer + 1
        accuraccy = true_numer * 1.0 / len(original_list)
        print("Test Accuracy:", accuraccy)
        return accuraccy
def report_accuracy(decoded_list, test_targets):
    """Print per-image recognition results and the exact-match accuracy.

    Note: despite the original (Chinese) comment describing a per-character
    percentage, the code counts whole-sequence exact matches.

    Args:
        decoded_list: sparse tensor of decoded (predicted) label sequences.
        test_targets: sparse tensor of ground-truth label sequences.
    """
    original_list = decode_sparse_tensor(test_targets)  # labelled characters per image
    detected_list = decode_sparse_tensor(decoded_list)  # recognized characters per image
    true_numer = 0
    # print(detected_list)
    if len(original_list) != len(detected_list):
        print(
            "len(original_list) 当前图片的标记字符长度", len(original_list),
            "len(detected_list) 从当前图片识别出的字符长度", len(detected_list),
            " test and detect length desn't match(从当前图片识别出的字符长度与标记的字符长度不一致)")
        return
    print(
        "T/F: original(length) 标记的字符长度 <-------> detectcted(length) 识别的字符长度")
    for idx, number in enumerate(original_list):
        detect_number = detected_list[idx]
        # Exact sequence match (same length AND same characters).
        hit = (number == detect_number)
        print("是否识别当前图片:", hit, " 当前图片的标记字符为:", number, "(",
              len(number), "位 ) <-------> 从当前图片中识别出的字符为:",
              detect_number, "(", len(detect_number), "位 )")
        if hit:
            true_numer = true_numer + 1
    print("Test Accuracy(测试的准确率):",
          true_numer * 1.0 / len(original_list) * 1.0)
def detect(test_inputs, test_targets, test_seq_len):
    """Restore a trained model and return the decoded label sequences.

    Args:
        test_inputs: batch of input feature sequences fed to the network.
        test_targets: unused here (kept for interface symmetry with callers).
        test_seq_len: per-sample input sequence lengths for CTC.

    Returns:
        List of decoded label sequences from the beam-search decoder.
    """
    logits, inputs, targets, seq_len, W, b = model.get_train_model()
    # NOTE(review): tf.contrib.ctc is a pre-1.0 TensorFlow API location.
    decoded, log_prob = tf.contrib.ctc.ctc_beam_search_decoder(
        logits, seq_len, merge_repeated=False)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restore variables from disk (hard-coded checkpoint path).
        saver.restore(sess, "models/ocr.model-0.5-56499")
        print("Model restored.")
        #feed_dict = {inputs: test_inputs, targets: test_targets, seq_len: test_seq_len}
        feed_dict = {inputs: test_inputs, seq_len: test_seq_len}
        dd = sess.run(decoded[0], feed_dict=feed_dict)
        return decode_sparse_tensor(dd)
def detect():
    """Batch-evaluate the OCR model over a test list file (Python 2 code).

    Iterates `factor1` batches of `factor2` images, writes each decoded
    transcription to `output_folder`, and returns the mean per-batch
    character accuracy.
    """
    model_restore = "models/ocr.model-0.959379192885-161999"  #give the path of your final model file
    test_filename = "test.txt"  #txt file containing the paths of all the test images
    output_folder = "test_outputs3/"  #where the outputs will be stored
    factor1 = 35  #think of it as number of test_batches
    factor2 = 41  #think of it as the batch size
    ac = 0
    logits, inputs, targets, seq_len, W, b = model.get_train_model()
    decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits,
                                                      seq_len,
                                                      merge_repeated=False)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restore variables from disk.
        saver.restore(sess, model_restore)
        #saver.restore(sess, "models2/ocr.model-0.929263617018-35999")
        print("Model restored.")
        a = 0
        for x in range(0, factor1):
            # Fetch batch x: names, inputs, sparse targets, sequence lengths.
            test_names, test_inputs, test_targets, test_seq_len = utils.get_data_set(
                test_filename, x * factor2, (x + 1) * factor2)
            print test_inputs[0].shape
            feed_dict = {inputs: test_inputs, seq_len: test_seq_len}
            dd, lp = sess.run([decoded[0], log_prob], feed_dict=feed_dict)
            original_list = decode_sparse_tensor(test_targets)
            detected_list = decode_sparse_tensor(dd)
            names_list = test_names.tolist()
            print "lp", lp
            # NOTE(review): this inner loop reuses `x` as its index, shadowing
            # the outer batch counter; `file` also shadows the builtin.
            for x, fname_save in enumerate(names_list):
                result = detected_list[x]
                # Write the decoded transcription next to the image name.
                file = codecs.open(
                    output_folder + os.path.basename(fname_save) + ".rnn.txt",
                    "w", "utf-8")
                #file.write(''.join(result.tolist())) if result is numpy
                file.write(''.join(result))
                file.close()
            if len(original_list) != (len(detected_list)):
                print("len(original_list)", len(original_list),
                      "len(detected_list)", len(detected_list),
                      " test and detect length desn't match")
                return
            print("T/F: original(length) <-------> detectcted(length)")
            total_ed = 0
            total_len = 0
            for idx, number in enumerate(original_list):
                detect_number = detected_list[idx]
                """if os.path.exists("output/"+names_list[idx] + ".out.txt"):
                    append_write = 'a' # append if already exists
                else:
                    append_write = 'w' # make a new file if not
                f = codecs.open("output/"+names_list[idx] + ".out.txt",append_write, 'utf-8')
                f.write("\nDetected: "+''.join(detect_number)+"\n"+"Original: ",''.join(number))
                f.close()"""
                ed = editdistance.eval(number, detect_number)
                ln = len(number)
                # NOTE(review): under Python 2 this is integer division, so
                # per-sample edit_accuracy truncates to 0 unless ed == 0.
                edit_accuracy = (ln - ed) / ln
                """if (idx % 10 == 0):
                    print("Edit: ", ed, "Edit accuracy: ", edit_accuracy,"\n",
                    ''.join(number).encode('utf-8'), "(", len(number), ") <-------> ",
                    ''.join(detect_number).encode('utf-8'), "(", len(detect_number), ")") """
                total_ed += ed
                total_len += ln
            # Batch-level character accuracy from accumulated edit distance.
            accuraccy = (total_len - total_ed) / total_len
            print("Test Accuracy:", accuraccy)
            ac += accuraccy
        # Mean accuracy across all batches.
        return ac / factor1
def train():
    """Main OCR training loop: restore checkpoint, train, report, save.

    Builds the graph via `neural_networks()`, restores the latest checkpoint
    from the RL32 model directory (retrying with a rolled-back checkpoint
    file on restore failure), then trains indefinitely, periodically
    reporting Levenshtein-ratio accuracy and saving checkpoints.
    """
    inputs, labels, global_step, lr, summary, \
        res_loss, res_optim, seq_len, res_acc, res_decoded, net_res = neural_networks()
    curr_dir = os.path.dirname(__file__)
    model_dir = os.path.join(curr_dir, MODEL_SAVE_NAME)
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)
    model_R_dir = os.path.join(model_dir, "RL32")
    if not os.path.exists(model_R_dir):
        os.mkdir(model_R_dir)
    log_dir = os.path.join(model_dir, "logs")
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    with tf.Session() as session:
        print("tf init")
        # init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        # session.run(init_op)
        session.run(tf.global_variables_initializer())
        print("tf check restore")
        # r_saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='OCR'), sharded=True, max_to_keep=5)
        r_saver = tf.train.Saver(max_to_keep=5)
        # Up to 3 restore attempts; on failure the checkpoint file is rewound
        # by one BATCHES-worth of steps and the restore is retried.
        for i in range(3):
            ckpt = tf.train.get_checkpoint_state(model_R_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print("Restore Model OCR...")
                stem = os.path.basename(ckpt.model_checkpoint_path)
                # Step count is encoded as the checkpoint-name suffix.
                restore_iter = int(stem.split('-')[-1])
                try:
                    r_saver.restore(session, ckpt.model_checkpoint_path)
                except:
                    # Corrupt/partial checkpoint: point the checkpoint file at
                    # an earlier save and retry.
                    new_restore_iter = restore_iter - BATCHES
                    with open(os.path.join(model_R_dir, "checkpoint"), 'w') as f:
                        f.write('model_checkpoint_path: "OCR.ckpt-%s"\n' % new_restore_iter)
                        f.write('all_model_checkpoint_paths: "OCR.ckpt-%s"\n' % new_restore_iter)
                    continue
                session.run(tf.assign(global_step, restore_iter))
                # Learning-rate schedule keyed on the restored step count.
                if restore_iter < 10000:
                    session.run(tf.assign(lr, 1e-4))
                elif restore_iter < 50000:
                    session.run(tf.assign(lr, 1e-5))
                else:
                    session.run(tf.assign(lr, 1e-6))
                print("Restored to %s." % restore_iter)
                break
            else:
                break
            # NOTE(review): unreachable — both branches above break; the
            # collapsed source is ambiguous about where these two lines sat.
            print("restored fail, return")
            return
        print("tf create summary")
        train_writer = tf.summary.FileWriter(log_dir, session.graph)
        print("tf train")
        AllLosts = {}
        # Rolling windows for the reported average accuracy / loss.
        accs = deque(maxlen=200)
        losts = deque(maxlen=200)
        while True:
            errR = 1
            batch_size = BATCH_SIZE
            for batch in range(BATCHES):
                start = time.time()
                train_inputs, train_labels, train_seq_len, train_info = get_next_batch_for_res(batch_size)
                feed = {inputs: train_inputs, labels: train_labels, seq_len: train_seq_len}
                feed_time = time.time() - start
                start = time.time()
                # _res = session.run(net_res, feed)
                # print(train_inputs.shape)
                # print(_res.shape)
                # print(train_seq_len[0])
                errR, acc, _, steps, res_lr = session.run(
                    [res_loss, res_acc, res_optim, global_step, lr], feed)
                font_length = int(train_info[0][-1])
                font_info = train_info[0][0] + "/" + train_info[0][1] + "/" + str(font_length)
                accs.append(acc)
                avg_acc = sum(accs) / len(accs)
                losts.append(errR)
                avg_losts = sum(losts) / len(losts)
                # errR = errR / font_length
                print("%s, %d time: %4.4fs / %4.4fs, acc: %.4f / %.4f, loss: %.4f / %.4f, lr:%.8f, info: %s " % \
                    (time.ctime(), steps, feed_time, time.time() - start, acc, avg_acc, errR, avg_losts, res_lr, font_info))
                # If the current loss is well above the rolling average, keep
                # training on this same batch (up to 10 extra steps).
                need_reset_global_step = False
                for _ in range(10):
                    if errR <= avg_losts * 2:
                        break
                    start = time.time()
                    errR, acc, _, res_lr = session.run(
                        [res_loss, res_acc, res_optim, lr], feed)
                    accs.append(acc)
                    avg_acc = sum(accs) / len(accs)
                    print("%s, %d time: 0.0000s / %4.4fs, acc: %.4f, avg_acc: %.4f, loss: %.4f, avg_loss: %.4f, lr:%.8f, info: %s " % \
                        (time.ctime(), steps, time.time() - start, acc, avg_acc, errR, avg_losts, res_lr, font_info))
                    need_reset_global_step = True
                # Extra steps bumped global_step; restore it so the schedule
                # counts only real batches.
                if need_reset_global_step:
                    session.run(tf.assign(global_step, steps))
                # if np.isnan(errR) or np.isinf(errR) :
                #     print("Error: cost is nan or inf")
                #     return
                # Exponential moving average of accuracy per font/sample key.
                for info in train_info:
                    key = ",".join(info)
                    if key in AllLosts:
                        AllLosts[key] = AllLosts[key] * 0.99 + acc * 0.01
                    else:
                        AllLosts[key] = acc
                # Unusually bad batch: dump its images for inspection.
                if acc / avg_acc <= 0.2:
                    for i in range(batch_size):
                        filename = "%s_%s_%s_%s_%s_%s_%s.png" % (acc, steps, i, \
                            train_info[i][0], train_info[i][1], train_info[i][2], train_info[i][3])
                        cv2.imwrite(os.path.join(curr_dir, "test", filename),
                                    train_inputs[i] * 255)
                # Periodic report: decode a fresh batch and print accuracy.
                if steps > 0 and steps % REPORT_STEPS == 0:
                    train_inputs, train_labels, train_seq_len, train_info = get_next_batch_for_res(batch_size)
                    decoded_list = session.run(res_decoded[0], {inputs: train_inputs, seq_len: train_seq_len})
                    for i in range(batch_size):
                        cv2.imwrite(os.path.join(curr_dir, "test", "%s_%s.png" % (steps, i)),
                                    train_inputs[i] * 255)
                    original_list = utils.decode_sparse_tensor(train_labels)
                    detected_list = utils.decode_sparse_tensor(decoded_list)
                    if len(original_list) != len(detected_list):
                        print("len(original_list)", len(original_list),
                              "len(detected_list)", len(detected_list),
                              " test and detect length desn't match")
                    print("T/F: original(length) <-------> detectcted(length)")
                    acc = 0.
                    for idx in range(min(len(original_list), len(detected_list))):
                        number = original_list[idx]
                        detect_number = detected_list[idx]
                        hit = (number == detect_number)
                        print("----------", hit, "------------")
                        print(list_to_chars(number), "(", len(number), ")")
                        print(list_to_chars(detect_number), "(", len(detect_number), ")")
                        # Accumulate the Levenshtein similarity ratio.
                        import Levenshtein
                        acc += Levenshtein.ratio(list_to_chars(number), list_to_chars(detect_number))
                    print("Test Accuracy:", acc / len(original_list))
                    # Show the 20 worst-performing font/sample keys.
                    sorted_fonts = sorted(AllLosts.items(), key=operator.itemgetter(1), reverse=False)
                    for f in sorted_fonts[:20]:
                        print(f)
                    # Learning-rate schedule keyed on the rolling average loss.
                    if avg_losts > 100:
                        session.run(tf.assign(lr, 1e-4))
                    elif avg_losts > 10:
                        session.run(tf.assign(lr, 5e-5))
                    else:
                        session.run(tf.assign(lr, 1e-5))
                    # If the current loss is NaN/inf, skip saving this model.
                    if np.isnan(errR) or np.isinf(errR):
                        continue
                    print("Save Model OCR ...")
                    r_saver.save(session, os.path.join(model_R_dir, "OCR.ckpt"),
                                 global_step=steps)
                    logs = session.run(summary, feed)
                    train_writer.add_summary(logs, steps)
def train():
    """Joint SRGAN + recognizer training loop.

    Trains three sub-networks — a recognizer (RES), an SRGAN generator (G)
    and discriminator (D) — restoring each from its own checkpoint
    directory. Per batch: 10 recognizer steps, then one G pre-train step,
    one adversarial G step and one D step. Periodically dumps sample
    images, reports Levenshtein-ratio accuracy, and saves all three models.
    """
    inputs, targets, labels, global_step, g_optim_init, d_loss, d_loss1, d_loss2, d_optim, \
        g_loss, g_mse_loss, g_res_loss, g_gan_loss, g_optim, net_g, \
        res_loss, res_optim, seq_len, res_acc, res_decoded = neural_networks()
    curr_dir = os.path.dirname(__file__)
    model_dir = os.path.join(curr_dir, MODEL_SAVE_NAME)
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)
    # Separate checkpoint directories per sub-network.
    model_R_dir = os.path.join(model_dir, "R")
    model_D_dir = os.path.join(model_dir, "D")
    model_G_dir = os.path.join(model_dir, "G")
    if not os.path.exists(model_R_dir):
        os.mkdir(model_R_dir)
    if not os.path.exists(model_D_dir):
        os.mkdir(model_D_dir)
    if not os.path.exists(model_G_dir):
        os.mkdir(model_G_dir)
    init = tf.global_variables_initializer()
    with tf.Session() as session:
        session.run(init)
        # Each saver only handles its own variable scope.
        r_saver = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='RES'), max_to_keep=5)
        d_saver = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='SRGAN_d'), max_to_keep=5)
        g_saver = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='SRGAN_g'), max_to_keep=5)
        ckpt = tf.train.get_checkpoint_state(model_G_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restore Model G...")
            g_saver.restore(session, ckpt.model_checkpoint_path)
        ckpt = tf.train.get_checkpoint_state(model_R_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restore Model R...")
            r_saver.restore(session, ckpt.model_checkpoint_path)
        ckpt = tf.train.get_checkpoint_state(model_D_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restore Model D...")
            d_saver.restore(session, ckpt.model_checkpoint_path)
        while True:
            for batch in range(BATCHES):
                for i in range(10):
                    train_inputs, train_targets, train_labels, train_seq_len = get_next_batch(
                        8)
                    feed = {
                        inputs: train_inputs,
                        targets: train_targets,
                        labels: train_labels,
                        seq_len: train_seq_len
                    }
                    # train res
                    start = time.time()
                    errM, acc, _, steps = session.run(
                        [res_loss, res_acc, res_optim, global_step], feed)
                    print("%d time: %4.4fs, res_loss: %.8f, res_acc: %.8f " %
                          (steps, time.time() - start, errM, acc))
                    if np.isnan(errM) or np.isinf(errM):
                        print("Error: cost is nan or inf")
                        return
                    # G/GAN/D updates run only on the first of the 10
                    # recognizer iterations.
                    if i > 0:
                        continue
                    train_inputs, train_targets, train_labels, train_seq_len = get_next_batch(
                        4)
                    feed = {
                        inputs: train_inputs,
                        targets: train_targets,
                        labels: train_labels,
                        seq_len: train_seq_len
                    }
                    # train G (MSE-only pre-training objective)
                    start = time.time()
                    errM, _, steps = session.run(
                        [g_mse_loss, g_optim_init, global_step], feed)
                    print("%d time: %4.4fs, g_mse_loss: %.8f " %
                          (steps, time.time() - start, errM))
                    if np.isnan(errM) or np.isinf(errM):
                        print("Error: cost is nan or inf")
                        return
                    # train GAN (SRGAN)
                    start = time.time()
                    ## update G with the combined mse + res + adversarial loss
                    errG, errM, errV, errA, _, steps = session.run([
                        g_loss, g_mse_loss, g_res_loss, g_gan_loss, g_optim,
                        global_step
                    ], feed)
                    print(
                        "%d time: %4.4fs, g_loss: %.8f (mse: %.6f res: %.6f adv: %.6f)"
                        % (steps, time.time() - start, errG, errM, errV, errA))
                    if np.isnan(errG) or np.isinf(errG) or np.isnan(
                            errA) or np.isinf(errA):
                        print("Error: cost is nan or inf")
                        return
                    start = time.time()
                    ## update D
                    errD, errD1, errD2, _, steps = session.run(
                        [d_loss, d_loss1, d_loss2, d_optim, global_step], feed)
                    print(
                        "%d time: %4.4fs, d_loss: %.8f (d_loss1: %.6f d_loss2: %.6f)"
                        % (steps, time.time() - start, errD, errD1, errD2))
                    if np.isnan(errD) or np.isinf(errD):
                        print("Error: cost is nan or inf")
                        return
                # Periodic report window (13 consecutive steps after each
                # REPORT_STEPS boundary — presumably to catch several batches).
                if steps > 0 and steps % REPORT_STEPS < 13:
                    train_inputs, train_targets, train_labels, train_seq_len = get_next_batch(
                        4)
                    feed = {inputs: train_inputs, targets: train_targets}
                    b_predictions = session.run(net_g, feed)
                    # Dump input / prediction / target stacked for each sample.
                    for i in range(4):
                        _predictions = np.reshape(b_predictions[i],
                                                  train_targets[i].shape)
                        _pred = np.transpose(_predictions)
                        _img = np.vstack((np.transpose(train_inputs[i]), _pred,
                                          np.transpose(train_targets[i])))
                        cv2.imwrite(
                            os.path.join(curr_dir, "test",
                                         "%s_%s.png" % (steps, i)),
                            _img * 255)
                    feed = {
                        inputs: train_inputs,
                        targets: train_targets,
                        labels: train_labels,
                        seq_len: train_seq_len
                    }
                    decoded_list = session.run(res_decoded[0], feed)
                    original_list = utils.decode_sparse_tensor(train_labels)
                    detected_list = utils.decode_sparse_tensor(decoded_list)
                    if len(original_list) != len(detected_list):
                        print("len(original_list)", len(original_list),
                              "len(detected_list)", len(detected_list),
                              " test and detect length desn't match")
                    print("T/F: original(length) <-------> detectcted(length)")
                    acc = 0.
                    for idx in range(
                            min(len(original_list), len(detected_list))):
                        number = original_list[idx]
                        detect_number = detected_list[idx]
                        hit = (number == detect_number)
                        print("%6s" % hit, list_to_chars(number), "(",
                              len(number), ")")
                        print("%6s" % "", list_to_chars(detect_number), "(",
                              len(detect_number), ")")
                        # Accumulate the Levenshtein similarity ratio.
                        import Levenshtein
                        acc += Levenshtein.ratio(list_to_chars(number),
                                                 list_to_chars(detect_number))
                    print("Test Accuracy:", acc / len(original_list))
                    print("Save Model R ...")
                    r_saver.save(session, os.path.join(model_R_dir, "R.ckpt"),
                                 global_step=steps)
                    print("Save Model D ...")
                    d_saver.save(session, os.path.join(model_D_dir, "D.ckpt"),
                                 global_step=steps)
                    print("Save Model G ...")
                    g_saver.save(session, os.path.join(model_G_dir, "G.ckpt"),
                                 global_step=steps)
def train():
    """Train the recognizer (RES) on images cleaned by a frozen generator.

    The generator (TRIM_G scope) is only restored — never trained here; its
    checkpoint is re-loaded after each report in case an external process
    updated it. Hard samples (tracked by a per-font loss EMA in AllLosts)
    are resampled with probability ~0.3 once enough statistics exist.
    """
    inputs, labels, global_step, \
        res_loss, res_optim, seq_len, res_acc, res_decoded, \
        net_g = neural_networks()
    curr_dir = os.path.dirname(__file__)
    model_dir = os.path.join(curr_dir, MODEL_SAVE_NAME)
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)
    model_G_dir = os.path.join(model_dir, "TG")
    model_R_dir = os.path.join(model_dir, "R16")
    if not os.path.exists(model_R_dir):
        os.mkdir(model_R_dir)
    if not os.path.exists(model_G_dir):
        os.mkdir(model_G_dir)
    init = tf.global_variables_initializer()
    with tf.Session() as session:
        session.run(init)
        # Separate savers per variable scope: recognizer and generator.
        r_saver = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='RES'), sharded=True)
        g_saver = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='TRIM_G'), sharded=False)
        ckpt = tf.train.get_checkpoint_state(model_G_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restore Model G...")
            g_saver.restore(session, ckpt.model_checkpoint_path)
        ckpt = tf.train.get_checkpoint_state(model_R_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restore Model R...")
            r_saver.restore(session, ckpt.model_checkpoint_path)
        # Per-font/sample loss EMA used for hard-example mining.
        AllLosts = {}
        while True:
            errA = errD1 = errD2 = 1
            batch_size = 4
            for batch in range(BATCHES):
                # ~30% of the time, resample one of the 11 highest-loss
                # font configurations instead of a random batch.
                if len(AllLosts) > 10 and random.random() > 0.7:
                    sorted_font = sorted(AllLosts.items(),
                                         key=operator.itemgetter(1),
                                         reverse=True)
                    font_info = sorted_font[random.randint(0, 10)]
                    font_info = font_info[0].split(",")
                    train_inputs, train_labels, train_seq_len, train_info = get_next_batch_for_res(batch_size, False, \
                        font_info[0], int(font_info[1]), int(font_info[2]), int(font_info[3]))
                else:
                    # train_inputs, train_labels, train_seq_len, train_info = get_next_batch_for_res(batch_size, False, _font_size=36)
                    train_inputs, train_labels, train_seq_len, train_info = get_next_batch_for_res(
                        batch_size)
                # feed = {inputs: train_inputs, labels: train_labels, seq_len: train_seq_len}
                start = time.time()
                # Clean the batch with the (frozen) generator first.
                p_net_g = session.run(net_g, {inputs: train_inputs})
                p_net_g = np.squeeze(p_net_g, axis=3)
                # Normalize each generated image: trim, clamp, resize, re-pad.
                for i in range(batch_size):
                    _t_img = utils.unsquare_img(p_net_g[i], image_height)
                    _t_img = utils.cvTrimImage(_t_img)
                    _t_img[_t_img < 0] = 0
                    _t_img = utils.resize(_t_img, image_height)
                    if _t_img.shape[0] * _t_img.shape[
                            1] <= image_size * image_size:
                        p_net_g[i] = utils.square_img(
                            _t_img, np.zeros([image_size, image_size]),
                            image_height)
                # Train the recognizer on the cleaned images.
                feed = {
                    inputs: p_net_g,
                    labels: train_labels,
                    seq_len: train_seq_len
                }
                errR, acc, _, steps = session.run(
                    [res_loss, res_acc, res_optim, global_step], feed)
                font_info = train_info[0][0] + "/" + train_info[0][
                    1] + " " + train_info[1][0] + "/" + train_info[1][1]
                print(
                    "%d time: %4.4fs, res_acc: %.4f, res_loss: %.4f, info: %s "
                    % (steps, time.time() - start, acc, errR, font_info))
                if np.isnan(errR) or np.isinf(errR):
                    print("Error: cost is nan or inf")
                    return
                # If accuracy drops below 90%, dump the images for inspection.
                if acc < 0.9:
                    for i in range(batch_size):
                        _img = np.vstack(
                            (train_inputs[i] * 255, p_net_g[i] * 255))
                        cv2.imwrite(
                            os.path.join(curr_dir, "test",
                                         "E%s_%s_%s.png" % (acc, steps, i)),
                            _img)
                # Update the per-key loss EMA (decay 0.95).
                for info in train_info:
                    key = ",".join(info)
                    if key in AllLosts:
                        AllLosts[key] = AllLosts[key] * 0.95 + errR * 0.05
                    else:
                        AllLosts[key] = errR
                # Periodic report.
                if steps > 0 and steps % REPORT_STEPS == 0:
                    train_inputs, train_labels, train_seq_len, train_info = get_next_batch_for_res(
                        batch_size)
                    p_net_g = session.run(net_g, {inputs: train_inputs})
                    p_net_g = np.squeeze(p_net_g, axis=3)
                    # Same normalization as above, with a 0.3 binarization
                    # threshold and zero-edge dropping.
                    for i in range(batch_size):
                        _t_img = utils.unsquare_img(p_net_g[i], image_height)
                        _t_img_bin = np.copy(_t_img)
                        _t_img_bin[_t_img_bin <= 0.3] = 0
                        _t_img = utils.dropZeroEdges(_t_img_bin, _t_img,
                                                     min_rate=0.1)
                        _t_img = utils.resize(_t_img, image_height)
                        if _t_img.shape[0] * _t_img.shape[
                                1] <= image_size * image_size:
                            p_net_g[i] = utils.square_img(
                                _t_img, np.zeros([image_size, image_size]),
                                image_height)
                    decoded_list = session.run(res_decoded[0], {
                        inputs: p_net_g,
                        seq_len: train_seq_len
                    })
                    # Dump input/generated pairs for this report step.
                    for i in range(batch_size):
                        _img = np.vstack((train_inputs[i], p_net_g[i]))
                        cv2.imwrite(
                            os.path.join(curr_dir, "test",
                                         "%s_%s.png" % (steps, i)),
                            _img * 255)
                    original_list = utils.decode_sparse_tensor(train_labels)
                    detected_list = utils.decode_sparse_tensor(decoded_list)
                    if len(original_list) != len(detected_list):
                        print("len(original_list)", len(original_list),
                              "len(detected_list)", len(detected_list),
                              " test and detect length desn't match")
                    print("T/F: original(length) <-------> detectcted(length)")
                    acc = 0.
                    for idx in range(
                            min(len(original_list), len(detected_list))):
                        number = original_list[idx]
                        detect_number = detected_list[idx]
                        hit = (number == detect_number)
                        print("%6s" % hit, list_to_chars(number), "(",
                              len(number), ")")
                        print("%6s" % "", list_to_chars(detect_number), "(",
                              len(detect_number), ")")
                        # Accumulate the Levenshtein similarity ratio.
                        import Levenshtein
                        acc += Levenshtein.ratio(list_to_chars(number),
                                                 list_to_chars(detect_number))
                    print("Test Accuracy:", acc / len(original_list))
                    # Show the 20 highest-loss keys being mined.
                    sorted_fonts = sorted(AllLosts.items(),
                                          key=operator.itemgetter(1),
                                          reverse=True)
                    for f in sorted_fonts[:20]:
                        print(f)
                    print("Save Model R ...")
                    r_saver.save(session, os.path.join(model_R_dir, "R.ckpt"),
                                 global_step=steps)
                    # Re-load the generator in case an external trainer
                    # produced a newer checkpoint; best-effort only.
                    try:
                        ckpt = tf.train.get_checkpoint_state(model_G_dir)
                        if ckpt and ckpt.model_checkpoint_path:
                            print("Restore Model G...")
                            g_saver.restore(session, ckpt.model_checkpoint_path)
                    except:
                        pass