def _table_grid_span(start, length, points, edge_margin=50):
    """Map a box's extent along one axis onto indices of sorted grid lines.

    Args:
        start: Box offset along the axis (x or y).
        length: Box extent along the axis (width or height).
        points: Sorted grid-line coordinates on that axis.
        edge_margin: Tolerance window for the first line (when searching the
            beginning) and the last line (when searching the end), which have
            no neighbour on that side.

    Returns:
        ``(begin, end)``. ``begin`` is the last line whose window (half the gap
        to the previous line) starts before ``start``. ``end`` is the first line
        whose window (half the gap to the next line) reaches past
        ``start + length``. Either stays 0 when no line qualifies.
    """
    begin = 0
    for idx, point in enumerate(points):
        half_gap = edge_margin if idx == 0 else (point - points[idx - 1]) / 2
        if start > point - half_gap:
            begin = idx
    end = 0
    stop = start + length
    last = len(points) - 1
    for idx, point in enumerate(points):
        half_gap = edge_margin if idx == last else (points[idx + 1] - point) / 2
        if stop < point + half_gap:
            end = idx
            break
    return begin, end


def _table_words(table_im):
    """OCR every text line inside a table crop and return per-character boxes.

    Each entry is ``[[x0, y0, x1, y0, x0, y1, x1, y1], char]`` in ``table_im``
    coordinates. ``x0``/``x1`` are the recognizer's per-character offsets
    shifted by the line box's ``[0]`` value. ``y0``/``y1`` are the line box's
    ``[1]`` and ``[5]`` values (read as top and bottom).
    """
    regions = text_predict(table_im, 1, 1, table_im)
    torch.cuda.empty_cache()
    words = []
    for region in regions:
        line_box = region[0]
        left, top, bottom = line_box[0], line_box[1], line_box[5]
        content = predict(Image.fromarray(region[1]).convert('L'))
        torch.cuda.empty_cache()
        chars, probs, offsets = content[0][0], content[0][1], content[1]
        # Positions whose top-1 probability exceeds 0.9 collapse to their best
        # character and lose their probability row. The remaining rows are
        # handed to ``calculate``. The list is rebuilt instead of being edited
        # while it is iterated.
        uncertain = []
        for idx, prob in enumerate(probs):
            if prob[0] > 0.9:
                chars[idx] = chars[idx][0]
            else:
                uncertain.append(prob)
        probs[:] = uncertain
        text = calculate((chars, probs, offsets))
        for idx, char in enumerate(text):
            x0 = offsets[idx][0] + left
            x1 = offsets[idx][1] + left
            words.append([[x0, top, x1, top, x0, bottom, x1, bottom], char])
    return words


def generate_table(cell, src):
    """Recover the cell layout and text of one table detected in ``src``.

    Args:
        cell: One ``extract_table`` entry. Indices read here:
            ``cell[0][1]`` is the table box ``[x, y, w, h]`` in ``src``.
            ``cell[1]`` / ``cell[2]`` are returned unchanged as ``cols`` / ``rows``.
            ``cell[3]`` / ``cell[4]`` are column / row grid-line coordinates
            inside the table crop.
            ``cell[5]`` holds detected cells as ``[image, [x, y, w, h]]``.
        src: Image array that the table box refers to.

    Returns:
        ``(new_table, rows, cols, pos)``. ``new_table`` is a list of
        ``[span, [text, width, height]]`` ordered by ``span['row_begin']``.
        Each ``span`` holds ``row_begin``/``row_end``/``col_begin``/``col_end``
        grid indices, with the end exclusive.

    Side effects: appends a span dict to the entries of ``cell[5]`` (all but
    the tallest one) and writes a debug overlay image to ``able1.jpg``.
    """
    import itertools  # local import keeps this block self-contained

    pos, cols, rows = cell[0][1], cell[1], cell[2]
    col_point = sorted(cell[3])
    row_point = sorted(cell[4])
    # Drop the tallest box (presumably the outline of the whole table), then
    # order the rest roughly from top-left to bottom-right.
    tables = sorted(cell[5], key=lambda c: c[1][3])[:-1]
    tables = sorted(tables, key=lambda c: c[1][0] + c[1][1])

    # All character positions inside the table.
    table_im = src[pos[1]:pos[1] + pos[3], pos[0]:pos[0] + pos[2]]
    word_list = _table_words(table_im)

    # Snap every detected cell onto the grid; a span is at least one unit.
    for c in tables:
        col_begin, col_end = _table_grid_span(c[1][0], c[1][2], col_point)
        row_begin, row_end = _table_grid_span(c[1][1], c[1][3], row_point)
        c.append({
            'col_begin': col_begin,
            'col_end': max(col_end, col_begin + 1),
            'row_begin': row_begin,
            'row_end': max(row_end, row_begin + 1),
        })

    # One candidate per grid unit, flagged 0 until a detected cell covers it.
    unit_cells = {}
    for r in range(len(row_point) - 1):
        for k in range(len(col_point) - 1):
            x0, x1 = col_point[k], col_point[k + 1]
            y0, y1 = row_point[r], row_point[r + 1]
            unit_cells[(r, k)] = [
                table_im[int(y0):int(y1), int(x0):int(x1)],
                [int(x0), int(y0), int(x1 - x0), int(y1 - int(y0))],
                {'col_begin': k, 'col_end': k + 1, 'row_begin': r, 'row_end': r + 1},
                0,
            ]
    # Check whether each unit was detected. Every unit under a merged cell
    # counts as covered, not only exact 1x1 matches.
    for c in tables:
        span = c[-1]
        for r in range(span['row_begin'], span['row_end']):
            for k in range(span['col_begin'], span['col_end']):
                unit = unit_cells.get((r, k))
                if unit is not None:
                    unit[-1] = 1
    # Units the detector missed become cells of their own.
    for unit in unit_cells.values():
        if unit[-1] == 0:
            print('not detect cell', unit[1:])
            tables.append(unit[:-1])

    # Drop wrongly detected duplicates: keep the largest box per grid span.
    best = {}
    for c in tables:
        span = c[2]
        key = (span['row_begin'], span['row_end'], span['col_begin'], span['col_end'])
        if key not in best or best[key][1][2] * best[key][1][3] < c[1][2] * c[1][3]:
            best[key] = c[:-1]
    tables = [
        [v[0], v[1], {'row_begin': k[0], 'row_end': k[1], 'col_begin': k[2], 'col_end': k[3]}]
        for k, v in best.items()
    ]

    save_table = table_im.copy()  # debug overlay: cell boxes and their indices
    for index_i, c in enumerate(tables):
        print('cell location: ', c[-1])
        x0, y0 = c[1][0], c[1][1]
        x1, y1 = x0 + c[1][2], y0 + c[1][3]
        cv2.putText(save_table, str(index_i), (x0 + 2, y0 + 2), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 255), 1)
        cv2.rectangle(save_table, (x0, y0), (x1, y1), (255, 0, 0), 1)
        # A character belongs to the cell whose box strictly contains its center.
        inside = [
            w for w in word_list
            if x0 < (w[0][0] + w[0][2]) / 2 < x1 and y0 < (w[0][1] + w[0][5]) / 2 < y1
        ]
        inside.sort(key=lambda w: w[0][1])
        # Characters sharing a top edge form one visual line. Every line,
        # including the last, is read left to right.
        text = ''.join(
            w[1]
            for _, line in itertools.groupby(inside, key=lambda w: w[0][1])
            for w in sorted(line, key=lambda w: w[0][0])
        )
        c.append([text, c[1][2], c[1][3]])
    Image.fromarray(save_table).save('able1.jpg')

    tables.sort(key=lambda c: c[2]['row_begin'])
    return [[c[2], c[3]] for c in tables], rows, cols, pos
def get_text(request):
    """Run full-page OCR on one document page and extract its fields.

    ``request`` is read with ``.get`` for these keys:
    ``img_path``, ``par``, ``task_id``, ``FT`` (document-type code) and
    ``page`` (1-based page number, used for PDFs).

    The page goes through these steps:
    - It is rendered (PDFs at 3x zoom) or loaded, then passed through
      ``fourier_demo``.
    - It is downscaled to fit 2000x2000 and saved under
      ``/home/ddwork/wce_data/ori_images/``.
    - It is OCR'd line by line.
    The recognized lines are stored via ``WCE.create`` and fed to the extractor
    returned by ``select``.

    Returns:
        A dict with ``result``, ``message``, ``taskid``, ``fields`` and ``FT``.
        Any exception is printed and reported as a failure dict with
        ``FT='FT999999999'``.
    """
    import re
    import time

    img_path = request.get('img_path')
    par = request.get('par')
    task_id = request.get('task_id')
    FT = request.get('FT')
    page = request.get('page')
    print(img_path)
    try:
        if img_path.lower().endswith('.pdf'):
            pdf = fitz.open(img_path)
            try:
                pm = pdf[int(page) - 1].getPixmap(matrix=fitz.Matrix(3, 3).preRotate(0), alpha=False)
                rendered = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
            finally:
                pdf.close()  # release the document once the page is rendered
            ori_img = fourier_demo(rendered, 'FT001')
        else:
            ori_img = fourier_demo(Image.open(img_path).convert('RGB'), 'FT001')
        f = select(FT[:11] + '001')
        print('FT:', FT[:11] + '001')

        ori_w, ori_h = ori_img.size
        input_img = ori_img.copy()
        input_img.thumbnail((2000, 2000), Image.ANTIALIAS)
        # Original-to-resized ratios, passed to text_predict together with ori_img.
        scale_w, scale_h = ori_w / input_img.size[0], ori_h / input_img.size[1]
        input_img = input_img.convert('RGB')
        data_image = str(os.path.splitext(img_path)[0].split('/')[-1]) + '_' + str(page)
        data_image = '/home/ddwork/wce_data/ori_images/{}_{}.jpg'.format(data_image, task_id)
        input_img.save(data_image)
        input_img = np.array(input_img)

        start = time.time()
        print("text_predict zhiqian")
        images = text_predict(input_img, scale_w, scale_h, ori_img)
        print("text_predict zhihou")
        torch.cuda.empty_cache()
        print(111111111111111111111111111, "SAD", time.time() - start, 'HAPPY')

        start = time.time()
        image_positions = []
        for j in images:
            try:
                print("predict front!!!!!!!!!!!!!!")
                content = predict(Image.fromarray(j[1]).convert('L'))
                print("predict back!!!!!!!!!!!!!!")
                # Positions whose top-1 probability exceeds 0.9 collapse to
                # their best character and lose their probability row. The list
                # is rebuilt rather than popped from while enumerating: popping
                # skipped the following row and misaligned indices with content[0].
                uncertain = []
                for idx, prob in enumerate(content[1]):
                    if prob[0] > 0.9:
                        content[0][idx] = content[0][idx][0]
                    else:
                        uncertain.append(prob)
                content[1][:] = uncertain
                line_text = calculate(content)
                image_positions.append([j[0], line_text.replace('“', '').replace('‘', '')])
            except Exception as e:
                print(e)
        data_json = WCE.create(field_id=int(task_id), par=str(par), image_path=data_image, FT=FT,
                               file_type=FT[:11], image_positions=str(image_positions),
                               edited=False, trained=False)
        data_json.save()
        print(222222222222222222222222222, time.time() - start)

        text = single_ocr(image_positions)
        print(text)
        texts = f.extract_info(img_path, page, FT[:11] + '001', text)
        print(texts)
        # Template-warp matching is disabled, so found_texts is always empty.
        found_texts = ''
        print('==================================================================')
        print(texts, found_texts)
        torch.cuda.empty_cache()

        # Qualification certificates: the concrete FT code comes from the
        # extracted 'version'. A missing issue date ('发证日期') is taken from
        # a date found in the file name, when there is one.
        if FT[:11] == 'FT001003110':
            FT = FT[:8] + texts.get('version')
            try:
                if not texts.get('发证日期'):
                    date_path = re.search(r'([0-9]{4}[-/年][0-9]{1,2}[-/月][0-9]{1,2}日?)',
                                          os.path.split(img_path)[1])
                    if date_path:
                        texts['发证日期'] = date_path.group(1)
            except Exception as e:
                print(e)

        if texts == 'FT999' and found_texts:
            return {'result': 'true', 'message': '请求成功', 'taskid': task_id, 'fields': found_texts, 'FT': FT}
        if texts != 'FT999' and found_texts == '':
            return {'result': 'true', 'message': '请求成功', 'taskid': task_id, 'fields': texts, 'FT': FT}
        if found_texts:
            for key, value in texts.items():
                if value == '' and key in found_texts:
                    texts[key] = found_texts[key]
        # Reaching here with texts == 'FT999' raises AttributeError, which is
        # reported as a failure below.
        blank = sum(1 for value in texts.values() if value == '')
        if blank == len(texts) - 1:
            return {'result': 'false', 'message': '请求失败', 'taskid': task_id, 'fields': {}, 'FT': 'FT999999999'}
        return {'result': 'true', 'message': '请求成功', 'taskid': task_id, 'fields': texts, 'FT': FT}
    except Exception as e:
        print(e)
        return {'result': 'false', 'message': '请求失败', 'taskid': task_id, 'fields': {}, 'FT': 'FT999999999'}
# NOTE(review): these imports sit in the middle of the source; this span looks
# like the head of a separate module pasted in. Confirm.
import numpy as np
from PIL import Image
import fitz
from full_ocr_local import single_ocr
# Template-warp field matching (get_warp) is currently disabled.
# from template_warp import get_warp
import json
import pickle
# Sanic server wiring is disabled here.
# NOTE(review): fast_ocr returns `response.text(...)`, so `response` must be
# provided elsewhere while this import stays commented out.
#from sanic import Sanic, response

# Import-time warm-up: run text detection once on 'test.png', resized to
# 1000x1000. Presumably this initializes the model before real requests; confirm.
# Requires 'test.png' in the working directory and `text_predict` to be defined
# before this point (it is not imported above).
a = Image.open('test.png').convert('RGB')
ori_w, ori_h = a.size
b = a.resize((1000, 1000))
scale_w, scale_h = b.size
b = np.array(b)
# Original-to-resized size ratios, passed to text_predict together with `a`.
scale_w, scale_h = ori_w / scale_w, ori_h / scale_h
text_predict(b, scale_w, scale_h, a)

#app = Sanic(__name__)

# Document-type (FT) codes grouped as "large"; not referenced in the visible code.
large_FT = [
    'FT001001001002', 'FT001001001003', 'FT001001001005', 'FT001003002001',
    'FT001001002001', 'FT001001003001'
]
def _fast_ocr_collapse(content):
    """Collapse confident recognizer positions in place and return ``content``.

    ``content[0]`` is indexed per position (``[idx][0]`` read as the best
    candidate). ``content[1]`` holds per-position probability rows, where
    ``[0]`` is presumably the top-1 probability. Positions above 0.9 are
    replaced by their best candidate and their probability row is dropped. The
    remaining rows are left for ``calculate``. The list is rebuilt rather than
    popped from while enumerating, which would skip rows and misalign indices.
    """
    chars, probs = content[0], content[1]
    uncertain = []
    for idx, prob in enumerate(probs):
        if prob[0] > 0.9:
            chars[idx] = chars[idx][0]
        else:
            uncertain.append(prob)
    probs[:] = uncertain
    return content


def _fast_ocr_tables(crop_img):
    """Detect the tables in ``crop_img`` and serialize each one for ``layout``.

    Returns a list of ``[[box, box], '###start<grid>******<merged>end###']``:
    - ``box`` is the table's ``[x0, y0, x1, y1]`` taken from ``generate_table``'s
      ``pos``.
    - ``<grid>`` is the row-major list of cell texts.
    - ``<merged>`` lists ``[[row_begin, col_begin], [row_end, col_end]]``
      (inclusive) for cells spanning more than one grid unit.
    Any error is printed, and the tables serialized so far are returned.
    """
    table_infos = []
    try:
        tables = extract_table(crop_img)
        generated = []
        if tables and tables != 'not table':
            for table in tables:
                table_time = time.time()
                generated.append(generate_table(table, crop_img))
                print('generate_table time is ', time.time() - table_time)
        for cells, rows, cols, pos in generated:
            table_info = [['' for _ in range(cols)] for _ in range(rows)]
            cell_info = []
            for tb in cells:
                d = tb[0]
                for row in range(d['row_begin'], d['row_end']):
                    for col in range(d['col_begin'], d['col_end']):
                        try:
                            table_info[row][col] += tb[1][0]
                        except (IndexError, TypeError):
                            print('cell error')
                            continue
                        if d not in cell_info:
                            cell_info.append(d)
            print(f'###start{str(table_info)}end###')
            x0, y0 = pos[0], pos[1]
            x1, y1 = pos[0] + pos[2], pos[1] + pos[3]
            merged = [
                [[d['row_begin'], d['col_begin']], [d['row_end'] - 1, d['col_end'] - 1]]
                for d in cell_info
                if not (d['row_end'] - d['row_begin'] == 1 and d['col_end'] - d['col_begin'] == 1)
            ]
            table_infos.append([
                [[x0, y0, x1, y1], [x0, y0, x1, y1]],
                f'###start{str(table_info)}******{str(merged)}end###',
            ])
    except Exception as ex:
        print('table error', ex)
    return table_infos


def _fast_ocr_group_lines(results):
    """Split y-sorted ``[box, text]`` OCR results into visual lines.

    Box indices are read as 0 = left, 1 = top, 6 = right, 7 = bottom. While
    walking the boxes, the line is cut after the current box in two cases:
    - the next box's top is at least 3/4 of the line head's height away from
      the head's top; or
    - the next box overlaps some box of the current line by more than half
      that box's width, as measured by ``count_area``.
    Every returned line is non-empty.
    """
    lines = []
    cut_index = 0
    curr_index = 0
    last = len(results) - 1
    for index in range(len(results)):
        if index == last:
            # Everything after the last cut belongs to the final line.
            lines.append(results[cut_index:])
            break
        try:
            head = results[curr_index][0]
            nxt = results[index + 1]
            if abs(nxt[0][1] - head[1]) < (head[7] - head[1]) * 3 / 4:
                for result in results[cut_index:index + 1]:
                    if count_area(nxt, result) > (result[0][6] - result[0][0]) / 2:
                        lines.append(results[cut_index:index + 1])
                        cut_index = curr_index = index + 1
                        # One cut is enough. Continuing over the stale slice
                        # would append empty lines.
                        break
            else:
                lines.append(results[cut_index:index + 1])
                cut_index = curr_index = index + 1
        except Exception as ex:
            # Best effort: a box that cannot be compared stays in the current line.
            print('line grouping error', ex)
    return lines


def fast_ocr(request):
    """OCR a user-selected region of an image or of a PDF page.

    Form fields read:
    - ``img_path``.
    - ``position``: comma-separated ``x0, y0, x1, y1`` in the client's image
      coordinates.
    - ``rotate``: passed to ``rotate_img``.
    - ``pageNum``: 1-based, PDFs only, default 1.
    - For PDFs, ``imageW`` / ``imageH``: the client-side size that
      ``position`` refers to.

    Tables in the crop are serialized first. The remaining text is detected,
    recognized, grouped into lines and laid out by ``layout``. If that
    produces nothing, the whole crop is recognized as a single line.

    Returns:
        ``response.text`` carrying a JSON list that holds the document.
    """
    import ast

    ori_time = time.time()
    ori_start_time = datetime.now()
    print('start...')
    img_path = request.form.get('img_path')
    print(img_path)
    print(request.form.get('position'))
    raw_position = '[' + request.form.get('position') + ']'
    rotate = int(request.form.get('rotate'))
    page = request.form.get('pageNum', 1)
    print(page)
    # The field comes straight from the client: parse literals only, never eval.
    position = ast.literal_eval(raw_position)

    if img_path.lower().endswith('pdf'):
        image_w = int(request.form.get('imageW'))
        image_h = int(request.form.get('imageH'))
        pdf = fitz.open(img_path)
        try:
            pm = pdf[int(page) - 1].getPixmap(matrix=fitz.Matrix(3, 3).preRotate(0), alpha=False)
            img = np.array(Image.frombytes("RGB", [pm.width, pm.height], pm.samples))
        finally:
            pdf.close()
    else:
        img = np.array(Image.open(img_path).convert('RGB'))
        image_h, image_w = img.shape[:-1]
    print('原图:', img.shape)
    print('rotate', rotate)

    # Clamp negative origins independently per axis, then map client
    # coordinates onto the rendered image.
    img_h, img_w = img.shape[:2]
    position[0] = position[0] if position[0] > 0 else 0
    position[1] = position[1] if position[1] > 0 else 0
    region = img[int(position[1] * img_h / image_h):int(position[3] * img_h / image_h),
                 int(position[0] * img_w / image_w):int(position[2] * img_w / image_w)]
    ori_img = rotate_img(Image.fromarray(region), rotate).convert('L')
    crop_img = np.array(ori_img.convert('RGB'))

    # Meant to skip table detection for very wide selections.
    # NOTE(review): crop_img is derived from ori_img, so the width ratio below
    # is always 1.0 and the test never holds. Table detection therefore always
    # runs. It probably meant to compare against the page width; confirm.
    table_infos = []
    start_table_time = time.time()
    if crop_img.shape[1] / crop_img.shape[0] > 3 and crop_img.shape[1] / np.array(ori_img).shape[1] < 0.3:
        print('判断不走表格!')
    else:
        table_infos = _fast_ocr_tables(crop_img)
    print('table detect time is ', time.time() - start_table_time)

    # Detect text lines on a copy no larger than 1500x1500.
    # NOTE(review): table boxes stay in pre-resize crop coordinates, while text
    # boxes below use the resized coordinates.
    st_time = time.time()
    crop_pil = Image.fromarray(crop_img)
    crop_pil.thumbnail((1500, 1500), Image.ANTIALIAS)
    crop_img = np.array(crop_pil)
    crop_area = text_predict(crop_img, 1, 1, crop_img)
    torch.cuda.empty_cache()
    print('ctpn time: ', time.time() - st_time, ' counts: ', len(crop_area))

    new_results = []
    for j in crop_area:
        try:
            content, _ = predict(Image.fromarray(j[1]).convert('L'))
            _fast_ocr_collapse(content)
            print(content)
            content = calculate(content)
            print(content)
            new_results.append([j[0], content])
        except Exception as ex:
            print(ex)
    new_results = sorted(new_results, key=lambda r: r[0][1])
    print(2222222222, len(new_results))

    texts = []
    for line in _fast_ocr_group_lines(new_results):
        line = sorted(line, key=lambda a: a[0][0] + a[0][1])
        text = ''
        for idx, (box, word) in enumerate(line):
            if not isinstance(word, str):
                continue  # e.g. an empty recognition result; nothing to add
            text += word
            # Add a space when the next box's right edge is more than three
            # average character widths away from this box's right edge.
            if idx < len(line) - 1 and word:
                char_w = abs(box[6] - box[0]) / len(word)
                if abs(line[idx + 1][0][6] - box[6]) > 3 * char_w:
                    text += ' '
        texts.append([[line[0][0], line[-1][0]], text])

    crop_w = crop_img.shape[1]
    document = layout(texts, crop_w, table_infos)
    if document == '':
        # Nothing was laid out: recognize the whole crop as one line.
        # ori_img is already a PIL image, so it is converted directly.
        try:
            content, _ = predict(ori_img.convert('L'))
            document = calculate(_fast_ocr_collapse(content))
        except Exception as ex:
            print(ex)
    print('ddddddddddddddd', document)
    if document == ([], []):
        document = ''
    ori_end_time = datetime.now()
    ori_return = json.dumps([document])
    print('ori_time;', time.time() - ori_time, '\n', 'ori_start_time:', ori_start_time, '\n',
          'ori_end_time:', ori_end_time)
    return response.text(ori_return)
def generate_table(cell, src=None):
    """Recognize the text of every detected cell of one table.

    Args:
        cell: One ``extract_table`` entry. Indices read here:
            ``cell[0][1]`` is the table box, returned as ``pos``.
            ``cell[1]`` / ``cell[2]`` are returned as ``cols`` / ``rows``.
            ``cell[3]`` / ``cell[4]`` are column / row grid-line coordinates.
            ``cell[5]`` holds detected cells as ``[image, [x, y, w, h]]``.
        src: Accepted and ignored, so that callers using the two-argument form
            ``generate_table(table, image)`` keep working. Each cell is read
            from its own image.

    Returns:
        ``(new_table, rows, cols, pos)``. ``new_table`` holds
        ``[span, [text, width, height]]`` per cell, ordered by the x + y of
        each cell's top-left corner. Every cell gets a text entry, which may be
        partial or empty when recognition fails.

    Side effects: for the entries of ``cell[5]`` (all but the tallest), the
    image is replaced by a PIL image, and a span dict and a text entry are
    appended.
    """
    pos, cols, rows = cell[0][1], cell[1], cell[2]
    col_point = sorted(cell[3])
    row_point = sorted(cell[4])
    # Drop the tallest box (presumably the whole-table outline), then order
    # the rest roughly from top-left to bottom-right.
    tables = sorted(cell[5], key=lambda i: i[1][3])[:-1]
    tables = sorted(tables, key=lambda i: i[1][0] + i[1][1])

    def span(start, length, points):
        """Return (begin, end) grid-line indices for one axis of a box."""
        begin = 0
        for idx, point in enumerate(points):
            half_gap = 50 if idx == 0 else (point - points[idx - 1]) / 2
            if start > point - half_gap:
                begin = idx
        end = 0
        for idx, point in enumerate(points):
            half_gap = 50 if idx == len(points) - 1 else (points[idx + 1] - point) / 2
            if start + length < point + half_gap:
                end = idx
                break
        return begin, end

    for i in tables:
        col_begin, col_end = span(i[1][0], i[1][2], col_point)
        row_begin, row_end = span(i[1][1], i[1][3], row_point)
        # Every span covers at least one grid unit, as in the grid-aware variant.
        i.append({
            'col_begin': col_begin,
            'col_end': max(col_end, col_begin + 1),
            'row_begin': row_begin,
            'row_end': max(row_end, row_begin + 1),
        })

    def recognize(line_img):
        """Recognize one line image; confident (>0.9) positions collapse to top-1."""
        content = predict(Image.fromarray(line_img).convert('L'))
        uncertain = []
        for idx, prob in enumerate(content[1]):
            if prob[0] > 0.9:
                content[0][idx] = content[0][idx][0]
            else:
                uncertain.append(prob)
        content[1][:] = uncertain
        return calculate(content)

    for i in tables:
        texts = ''
        try:
            i[0] = Image.fromarray(i[0])
            ori_w, ori_h = i[0].size
            new_i = i[0].copy()
            new_i.thumbnail((1500, 1500), Image.ANTIALIAS)
            scale_w, scale_h = ori_w / new_i.size[0], ori_h / new_i.size[1]
            new_i = np.array(new_i.convert('RGB'))
            # Cells of 16 px or less in either dimension are left empty.
            if new_i.shape[1] > 16 and new_i.shape[0] > 16:
                images = text_predict(new_i, scale_w, scale_h, np.array(i[0]))
                torch.cuda.empty_cache()
                if images:
                    for image in sorted(images, key=lambda ii: ii[0][1]):
                        texts += recognize(image[1])
                elif new_i.any() and new_i.shape[0] < new_i.shape[1] * 1.5:
                    # No line detected: recognize the whole cell unless it is tall.
                    try:
                        texts += recognize(new_i)
                    except Exception as ex:
                        print('small_image_warning', ex)
        except Exception as e:
            print('table_text warning', e)
        # Always attach the text entry so every cell has index 3 below.
        i.append([texts, i[1][2], i[1][3]])

    return [[i[2], i[3]] for i in tables], rows, cols, pos
def _single_ocr_recognize(line_img):
    """Recognize one text-line image and strip curly quotes from the result.

    Positions whose top-1 probability exceeds 0.9 collapse to their best
    character before the content is handed to ``calculate``. The probability
    list is rebuilt rather than edited while it is iterated.
    """
    content = predict(Image.fromarray(line_img).convert('L'))
    uncertain = []
    for idx, prob in enumerate(content[1]):
        if prob[0] > 0.9:
            content[0][idx] = content[0][idx][0]
        else:
            uncertain.append(prob)
    content[1][:] = uncertain
    return calculate(content).replace('“', '').replace('‘', '')


def _single_ocr_group_lines(results):
    """Split y-sorted ``[box, text]`` results into lines.

    A line is cut after the current box when the next box's top is more than
    4/5 of the line head's height (box index 7 minus index 1) away from the
    head's top. Boxes after the last cut form the final line, including the
    last box.
    """
    lines = []
    cut_index = 0
    curr_index = 0
    for index in range(len(results) - 1):
        head = results[curr_index][0]
        if abs(results[index + 1][0][1] - head[1]) > (head[7] - head[1]) * 4 / 5:
            lines.append(results[cut_index:index + 1])
            cut_index = curr_index = index + 1
    if results:
        lines.append(results[cut_index:])
    return lines


def single_ocr(document, img_name, start_page, new_url):
    """OCR one page image and add its text and tables to ``document``.

    Args:
        document: Passed to ``save2docx``. It is returned unchanged when
            ``sort_paragraph`` fails.
        img_name: Page image, deskewed by ``skew_detect.get_rotated_img``,
            whose result is used as a PIL image (``.size``, ``.thumbnail``).
        start_page: Passed through to ``save2docx``.
        new_url: Unused.

    Returns:
        ``(document, position)``. ``position`` lists ``[box, text]`` for every
        recognized text box, line by line and left to right within each line.
    """
    img_name = skew_detect.get_rotated_img(img_name)
    ori_img = np.array(img_name)
    ori_w, ori_h = img_name.size
    img_name.thumbnail((1500, 1500), Image.ANTIALIAS)
    # Original-to-resized ratios, passed to text_predict together with ori_img.
    scale_w, scale_h = ori_w / img_name.size[0], ori_h / img_name.size[1]
    print('原图大小:', ori_w, ori_h, '缩放比例:', scale_w, scale_h)
    img = np.array(img_name)
    start = time.time()
    images = text_predict(img, scale_w, scale_h, ori_img)
    torch.cuda.empty_cache()
    print('ctpn time: ', time.time() - start)

    try:
        tables = extract_table(ori_img)
        # extract_table signals "no table" with 'not table'. An empty or
        # missing result is treated the same way.
        has_table = bool(tables) and tables != 'not table'
    except Exception as e:
        print(e)
        tables, has_table = [], False
    print(2222222222222222222222222, has_table)

    results = []
    start = time.time()
    for j in images:
        try:
            # Text boxes whose top edge lies strictly inside a table's vertical
            # band are left to generate_table.
            if has_table and any(t[0][1][1] < j[0][1] < t[0][1][1] + t[0][1][3] for t in tables):
                continue
            results.append([j[0], _single_ocr_recognize(j[1])])
        except Exception as e:
            print(e)
    torch.cuda.empty_cache()
    print(33333333333333333, time.time() - start)

    results = sorted(results, key=lambda r: r[0][1])
    texts = []
    position = []
    for line in _single_ocr_group_lines(results):
        line = sorted(line, key=lambda a: a[0][0])
        text = ''
        for idx, (box, word) in enumerate(line):
            position.append([box, word])
            text += word
            # Add a space when the gap to the next box exceeds three average
            # character widths of this box.
            if idx < len(line) - 1 and word:
                if abs(line[idx + 1][0][0] - box[6]) > 3 * (abs(box[6] - box[0]) / len(word)):
                    text += ' '
        texts.append([[line[0][0], line[-1][0]], text])
    print(img_name.size)

    if has_table:
        for table in tables:
            table_top = table[0][1][1]
            # Insert the table after the last entry whose top lies above it.
            table_index = 0
            for index, entry in enumerate(texts):
                if entry[0] == 'table':
                    entry_top = entry[1][3][1]  # pos of an already inserted table
                else:
                    entry_top = entry[0][0][1]  # first box of a text line
                if table_top > entry_top:
                    table_index = index + 1
            try:
                texts.insert(table_index, ['table', generate_table(table, ori_img)])
            except Exception as e:
                print(e)

    try:
        texts = sort_paragraph(Image.fromarray(ori_img), texts)
    except Exception as e:
        print(e)
        return document, position
    document = save2docx(document, texts, Image.fromarray(ori_img), start_page)
    return document, position