def test_bktree():
    test_passed = True
    test_database = pickle.load(open('testdatabase.p', 'rb'))
    test_database = test_database['bktree']
    keys = test_database['keys']
    result_target = test_database['result']
    query = test_database['query']

    BKT = BKTree(levenshtein_distance_DP)
    for key in keys:
        BKT.insert(key)

    result = BKT.get(query)
    for (word, distance) in result:
        encounter = 0
        for (word_target, distance_target) in result_target:
            if word == word_target:
                if distance != distance_target:
                    test_passed = False
                encounter += 1
        if encounter != 1:
            test_passed = False
    return test_passed
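# The test above exercises an insert()/get() interface whose implementation is not
# reproduced in this collection. Below is a minimal BK-tree sketch with that shape;
# the node layout, the `tolerance` attribute, and its default radius of 2 are
# assumptions for illustration, not the project's actual BKTree class.
class BKTreeSketch:
    def __init__(self, distance_fn, tolerance=2):
        self.distance_fn = distance_fn
        self.tolerance = tolerance      # assumed default search radius
        self.root = None                # node = (word, {edge_distance: child_node})

    def insert(self, word):
        if self.root is None:
            self.root = (word, {})
            return
        node = self.root
        while True:
            d = self.distance_fn(word, node[0])
            child = node[1].get(d)
            if child is None:
                node[1][d] = (word, {})
                return
            node = child

    def get(self, query):
        """Return [(word, distance), ...] for all words within self.tolerance of query."""
        if self.root is None:
            return []
        results, stack = [], [self.root]
        while stack:
            word, children = stack.pop()
            d = self.distance_fn(query, word)
            if d <= self.tolerance:
                results.append((word, d))
            # Triangle inequality: only children whose edge label lies in
            # [d - tolerance, d + tolerance] can contain a match.
            for edge, child in children.items():
                if d - self.tolerance <= edge <= d + self.tolerance:
                    stack.append(child)
        return results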
def test_bk_nearest_neighbor_search():
    """Test the bk_nearest_neighbor_search function."""
    # --- Preparations
    string_list = ['one', 'two', 'three', 'four', 'five',
                   'six', 'seven', 'eight', 'nine', 'ten']
    tree = BKTree(string_list)

    # --- Exercise functionality
    search1 = bk_nearest_neighbor_search('eight', tree)
    search2 = bk_nearest_neighbor_search('ter', tree)
    search3 = bk_nearest_neighbor_search('123456789', tree)
    search4 = bk_nearest_neighbor_search('ffff', tree)

    # --- Check results
    assert len(search1) == 2
    assert search1[0] == 0
    assert search1[1:] == ['eight']

    assert len(search2) == 2
    assert search2[0] == 1
    assert search2[1:] == ['ten']

    assert len(search3) == 11
    assert search3[0] == 9
    for value in search3[1:]:
        assert value in string_list

    assert len(search4) == 3
    assert search4[0] == 3
    for value in search4[1:]:
        assert value in ['five', 'four']
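# The asserts above pin down the expected return shape: element 0 is the best edit
# distance, the remaining elements are every word tied at that distance. A brute-force
# reference with the same shape is sketched below (it is not the tree-based function
# under test, just a plain-list equivalent using a standard Levenshtein DP).
def levenshtein_reference(a, b):
    # classic dynamic-programming edit distance over two rows
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1,                 # deletion
                            curr[j - 1] + 1,             # insertion
                            prev[j - 1] + (ca != cb)))   # substitution
        prev = curr
    return prev[-1]


def brute_force_nearest(query, words, distance_fn=levenshtein_reference):
    best = min(distance_fn(query, w) for w in words)
    return [best] + [w for w in words if distance_fn(query, w) == best]


# e.g. brute_force_nearest('ter', ['ten', 'two']) -> [1, 'ten']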
def main():
    og_ids = [
        "1", "2", "18",  # Cellosaurus with relevant terms for human biology
        "5", "7", "9", "19"
    ]
    ogs = [load_ontology.load(x)[0] for x in og_ids]
    str_to_terms = defaultdict(lambda: [])
    print("Gathering all term string identifiers in ontologies...")
    string_identifiers = set()
    for og in ogs:
        for id, term in og.id_to_term.items():
            str_to_terms[term.name].append([term.id, "TERM_NAME"])
            string_identifiers.add(term.name)
            for syn in term.synonyms:
                str_to_terms[syn.syn_str].append([term.id, "SYNONYM_%s" % syn.syn_type])
                string_identifiers.add(syn.syn_str)

    print("Building the BK-Tree...")
    bk_tree = BKTree(string_metrics.bag_dist_multiset, string_identifiers)

    # with open("fuzzy_match_bk_tree.pickle", "w") as f:
    with open("fuzzy_match_bk_tree.pickle", "wb") as f:
        pickle.dump(bk_tree, f)
    with open("fuzzy_match_string_data.json", "w") as f:
        f.write(json.dumps(str_to_terms, indent=4, separators=(',', ': ')))
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="Build BK-tree for fuzzy matching.\n"
                    "Example of ID specification:\n\thuman: 1,2,18,5,7,9,19\n"
                    "\tarabidopsis: 20,21")
    parser.add_argument("-i", "--ids", required=True,
                        help="comma separated list of ontology IDs.")
    parser.add_argument("-j", "--json_filename", required=True,
                        help="filename of output json. "
                             "'fuzzy_match_string_data_SPECIES.json' is recommended.")
    parser.add_argument("-p", "--pickle_filename", required=True,
                        help="filename of output pickle. "
                             "'fuzzy_match_bk_tree_SPECIES.pickle' is recommended.")
    args = parser.parse_args()

    og_ids = args.ids.split(",")
    ogs = [load_ontology.load(x)[0] for x in og_ids]
    str_to_terms = defaultdict(lambda: [])
    print("Gathering all term string identifiers in ontologies...")
    string_identifiers = set()
    for og in ogs:
        for id, term in og.id_to_term.items():
            str_to_terms[term.name].append([term.id, "TERM_NAME"])
            string_identifiers.add(term.name)
            for syn in term.synonyms:
                str_to_terms[syn.syn_str].append([term.id, "SYNONYM_%s" % syn.syn_type])
                string_identifiers.add(syn.syn_str)

    print("Building the BK-Tree...")
    bk_tree = BKTree(string_metrics.bag_dist_multiset, string_identifiers)

    with open(args.pickle_filename, "wb") as f:
        pickle.dump(bk_tree, f)
    with open(args.json_filename, "w") as f:
        f.write(json.dumps(str_to_terms, indent=4, separators=(',', ': ')))
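# A hedged sketch of how a downstream consumer might reload the two artifacts written
# above and fuzzy-match a query string. The file names are the defaults from the
# non-argparse variant, and the query method name/signature on this BKTree class is an
# assumption; substitute whatever the project's BKTree actually exposes.
import json
import pickle

with open("fuzzy_match_bk_tree.pickle", "rb") as f:
    tree = pickle.load(f)
with open("fuzzy_match_string_data.json") as f:
    str_to_terms = json.load(f)

# assumed signature: query(string, max_distance) -> iterable of matching identifiers
for match in tree.query("T cell", 2):
    print(match, str_to_terms.get(str(match), []))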
def __init__(self, proxy_map):
    super(SpecificWorker, self).__init__(proxy_map)
    self.timer.timeout.connect(self.compute)
    self.Period = 200
    self.timer.start(self.Period)

    # load the pre-trained EAST text detector
    print "[INFO] loading EAST text detector..."
    self.net = cv2.dnn.readNet(NET_FILE)

    self.model = crnn.CRNN(32, 1, 37, 256)
    if torch.cuda.is_available():
        self.model = self.model.cuda()
    print "[INFO] loading CRNN text recognizer..."
    self.model.load_state_dict(torch.load(MODEL_FILE))

    self.tree = None
    if self.use_lexicon:
        print "[INFO] loading generic english lexicon..."
        lexicon = []
        with open(LEXICON_FILE) as f:
            for line in f.read().splitlines():
                lexicon.append(line.lower())
        print "Length of the lexicon: ", len(lexicon)
        self.tree = BKTree(hamming_distance, lexicon)
def __init__(self,
             filepath_corpus,
             filepath_bk_tree,
             filepath_token_cnts,
             vocab_path,
             ngram_range=(2, 3),
             token_pattern=r"(?u)\b\w\w+\b",
             preprocessor=None,
             tokenizer=None,
             lowercase=True,
             stop_words=None):  # list of stopwords; keep it to None
    self.filepath_corpus = filepath_corpus
    self.filepath_bk_tree = filepath_bk_tree
    self.vocab_path = vocab_path
    self.ngram_range = ngram_range
    self.stop_words = stop_words
    self.lowercase = lowercase
    self.preprocessor = preprocessor
    self.tokenizer = tokenizer
    self.token_pattern = token_pattern

    with open(vocab_path, 'rb') as f:
        word2idx, idx2word = pickle.load(f)
    self.vocab_size = len(word2idx)

    try:
        with open(self.filepath_bk_tree, 'rb') as f:
            self.tree = pickle.load(f)
    except Exception as e:
        pickle_dump(BKTree(items_dict=word2idx), self.filepath_bk_tree)
        with open(self.filepath_bk_tree, 'rb') as f:
            self.tree = pickle.load(f)

    self.token_cnts = None
    self.bi_pre_token_cnts = None
    self.bi_post_token_cnts = None
    self.tri_token_cnts = None
    self.filepath_token_cnts = None
    self.filepath_bi_pre_token_cnts = None
    self.filepath_bi_post_token_cnts = None
    self.filepath_tri_token_cnts = None

    if filepath_token_cnts is not None and os.path.exists(filepath_token_cnts):
        print('LOADING...')
        with io.open(filepath_token_cnts, encoding='utf-8') as f:
            self.token_cnts = json.load(f)
        with io.open('bi_pre_' + filepath_token_cnts, encoding='utf-8') as f:
            self.bi_pre_token_cnts = json.load(f)
        with io.open('bi_post_' + filepath_token_cnts, encoding='utf-8') as f:
            self.bi_post_token_cnts = json.load(f)
        with io.open('tri_' + filepath_token_cnts, encoding='utf-8') as f:
            self.tri_token_cnts = json.load(f)
    else:
        print('GETTING FROM CORPUS...')
        self.token_cnts = defaultdict()
        self.bi_pre_token_cnts = defaultdict()
        self.bi_post_token_cnts = defaultdict()
        self.tri_token_cnts = defaultdict()
        self.filepath_token_cnts = filepath_token_cnts
        self.filepath_bi_pre_token_cnts = 'bi_pre_' + filepath_token_cnts
        self.filepath_bi_post_token_cnts = 'bi_post_' + filepath_token_cnts
        self.filepath_tri_token_cnts = 'tri_' + filepath_token_cnts
        self.build_ngrams_count_from_corpus()
print("Constructing dataset........", end='', flush=True) dataset = Dataset(csv_file, query_processor) dataset.load() print("Dataset Loaded!") #process it into an efficient binary search tree print("Constructing RB-Tree...", end='', flush=True) binary_search_tree = RedBlackTree() temp_dataset = dataset.get_token_bookids() for i in range(len(temp_dataset)): binary_search_tree.insert(temp_dataset[i][0], temp_dataset[i][1]) print("RB-Tree constructed!") #create a BK-tree print("Constructing BK-Tree...", end='', flush=True) bk_tree = BKTree(levenshtein_distance_DP) english_dictionary_list = dataset.get_dictionary() for i in range(len(english_dictionary_list)): bk_tree.insert(english_dictionary_list[i]) print("BK-Tree constructed!") if CONSTRUCT_DEBUG_MODE: pickle.dump(dataset, open("dataset.p", 'wb')) pickle.dump(binary_search_tree, open("binary_search_tree.p", 'wb')) pickle.dump(bk_tree, open("bk_tree.p", 'wb')) #create a Ranker ranker = Ranker(dataset, query_processor, binary_search_tree, bk_tree) #Get the book-titles for “The search engine.” ranker.evaluate("The search engine.")
class PriorGenerator:
    def __init__(self, zero_to_alpha):
        # self.freq = {}
        # self.successor = {}
        self.word_successor = {}
        self.dictionary = set()
        self.stopSymbols = []
        self.zero_to_alpha = zero_to_alpha
        self.bk_tree = BKTree()

    def get_word_successor(self, w1, w2):
        try:
            return self.word_successor[w1][w2]
        except:
            return self.zero_to_alpha

    def load_stop_symbols_from_file(self, path):
        with open(path) as file:
            for x in file.readlines():
                self.stopSymbols.append(x[:-1])

    def remove_stop_symbols(self, line):
        for x in self.stopSymbols:
            line = line.replace(x, " ")
        return line

    def analize_freq(self, path):
        with open(path) as file:
            for line in file.readlines():
                line = line[:-1]
                line = self.remove_stop_symbols(line)
                # English letters only?
                line = line.lower()
                # Analyze the character frequencies
                # for i in range(len(line)):
                #     char = line[i]
                #     if (char > "z" or char < "a"):
                #         continue
                #     try:
                #         self.freq[char] += 1
                #     except:
                #         self.freq[char] = 1
                #     if i < len(line)-1:
                #         try:
                #             suc = self.successor[char]
                #             try:
                #                 suc[line[i+1]] += 1
                #             except:
                #                 suc[line[i+1]] = 1
                #         except:
                #             self.successor[char] = {}
                #             self.successor[char][line[i+1]] = 1
                # If the text is not English only, apostrophes and similar characters must also be handled.
                # Analyze the words
                line = line.split()
                for word_id in range(len(line)):
                    self.dictionary.add(line[word_id])
                    if word_id < len(line) - 1:
                        try:
                            suc = self.word_successor[line[word_id]]
                            try:
                                suc[line[word_id + 1]] += 1
                            except:
                                suc[line[word_id + 1]] = 1
                        except:
                            self.word_successor[line[word_id]] = {}
                            self.word_successor[line[word_id]][line[word_id + 1]] = 1

    def finalize(self):
        # tot = sum(self.freq.values())
        # for key in self.freq.keys():
        #     self.freq[key] /= tot
        #     tot2 = sum(self.successor[key].values())
        #     for key2 in self.successor[key].keys():
        #         self.successor[key][key2] /= tot2
        for key in self.word_successor.keys():
            tot = sum(self.word_successor[key].values())
            for key2 in self.word_successor[key].keys():
                self.word_successor[key][key2] /= tot
        for key in self.dictionary:
            self.bk_tree.addWord(key)

    def serialize(self, path):
        tmp_dic = {
            "Dictionary": list(self.dictionary),
            "Word_successor": self.word_successor,
            "BK_Tree": self.bk_tree.root
        }
        with open(path, "w") as outFile:
            json.dump(tmp_dic, outFile)

    def deserialize(self, path):
        with open(path) as inFile:
            tmp_dic = json.load(inFile)
        self.dictionary = tmp_dic["Dictionary"]
        # self.freq = tmp_dic["Freq"]
        # self.successor = tmp_dic["Successor"]
        self.word_successor = tmp_dic["Word_successor"]
        self.bk_tree = BKTree()
        self.bk_tree.root = tmp_dic["BK_Tree"]
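# Hedged usage sketch for PriorGenerator, using only the methods defined above.
# The corpus and stop-symbol file paths and the zero_to_alpha value are hypothetical;
# BKTree() here is the project's own addWord()/root variant, not a generic library.
gen = PriorGenerator(zero_to_alpha=1e-6)
gen.load_stop_symbols_from_file("stop_symbols.txt")   # one symbol per line
gen.analize_freq("corpus.txt")                        # fills the dictionary and bigram counts
gen.finalize()                                        # normalizes counts, populates the BK-tree
gen.serialize("prior.json")

# P(w2 | w1) estimate, falling back to zero_to_alpha for unseen pairs
print(gen.get_word_successor("the", "cat"))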
def main(argv=None):
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    try:
        os.makedirs(FLAGS.output_dir)
    except OSError as e:
        if e.errno != 17:
            raise

    bk_tree = BKTree(levenshtein, list_words(FLAGS.vocab))
    # bk_tree = bktree.Tree()

    with tf.get_default_graph().as_default():
        input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
        input_feature_map = tf.placeholder(tf.float32, shape=[None, None, None, 32], name='input_feature_map')
        input_transform_matrix = tf.placeholder(tf.float32, shape=[None, 6], name='input_transform_matrix')
        input_box_mask = []
        input_box_mask.append(tf.placeholder(tf.int32, shape=[None], name='input_box_masks_0'))
        input_box_widths = tf.placeholder(tf.int32, shape=[None], name='input_box_widths')

        input_seq_len = input_box_widths[tf.argmax(input_box_widths, 0)] * tf.ones_like(input_box_widths)
        global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)

        shared_feature, f_score, f_geometry = detect_part.model(input_images)
        pad_rois = roi_rotate_part.roi_rotate_tensor_pad(input_feature_map, input_transform_matrix, input_box_mask, input_box_widths)
        recognition_logits = recognize_part.build_graph(pad_rois, input_box_widths)
        _, dense_decode = recognize_part.decode(recognition_logits, input_box_widths)

        variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            im_fn_list = get_images()
            for im_fn in im_fn_list:
                # im = cv2.imread(im_fn)[:, :, ::-1]
                im = cv2.imread(im_fn)
                im = cv2.resize(im, (960, 540))
                start_time = time.time()
                im_resized, (ratio_h, ratio_w) = resize_image(im)
                # im_resized_d, (ratio_h_d, ratio_w_d) = resize_image_detection(im)

                timer = {'detect': 0, 'restore': 0, 'nms': 0, 'recog': 0}
                start = time.time()
                shared_feature_map, score, geometry = sess.run(
                    [shared_feature, f_score, f_geometry],
                    feed_dict={input_images: [im_resized]})

                boxes, timer = detect(score_map=score, geo_map=geometry, timer=timer)
                timer['detect'] = time.time() - start

                start = time.time()  # reset for recognition
                if boxes is not None and boxes.shape[0] != 0:
                    # res_file_path = os.path.join(FLAGS.output_dir, 'res_' + '{}.txt'.format(os.path.basename(im_fn).split('.')[0]))
                    res_file_path = os.path.join(FLAGS.output_dir, '{}.txt'.format(os.path.basename(im_fn)))

                    input_roi_boxes = boxes[:, :8].reshape(-1, 8)
                    recog_decode_list = []
                    # Here avoid too many text areas leading to OOM
                    for batch_index in range(input_roi_boxes.shape[0] // 32 + 1):  # test roi batch size is 32
                        start_slice_index = batch_index * 32
                        end_slice_index = (batch_index + 1) * 32 if input_roi_boxes.shape[0] >= (batch_index + 1) * 32 else input_roi_boxes.shape[0]
                        tmp_roi_boxes = input_roi_boxes[start_slice_index:end_slice_index]
                        boxes_masks = [0] * tmp_roi_boxes.shape[0]
                        transform_matrixes, box_widths = get_project_matrix_and_width(tmp_roi_boxes)
                        # max_box_widths = max_width * np.ones(boxes_masks.shape[0])  # seq_len

                        # Run end to end
                        recog_decode = sess.run(dense_decode, feed_dict={
                            input_feature_map: shared_feature_map,
                            input_transform_matrix: transform_matrixes,
                            input_box_mask[0]: boxes_masks,
                            input_box_widths: box_widths})
                        recog_decode_list.extend([r for r in recog_decode])

                    timer['recog'] = time.time() - start
                    # Preparing to draw boxes
                    boxes = boxes[:, :8].reshape((-1, 4, 2))
                    boxes[:, :, 0] /= ratio_w
                    boxes[:, :, 1] /= ratio_h

                    if len(recog_decode_list) != boxes.shape[0]:
                        print("detection and recognition result are not equal!")
                        exit(-1)

                    with open(res_file_path, 'w') as f:
                        for i, box in enumerate(boxes):
                            # to avoid submitting errors
                            box = sort_poly(box.astype(np.int32))
                            if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
                                continue
                            recognition_result = ground_truth_to_word(recog_decode_list[i])

                            if contain_eng(recognition_result):
                                print(recognition_result)
                                fix_result = bktree_search(bk_tree, recognition_result.lower())
                                print(fix_result)
                                if len(fix_result) != 0:
                                    recognition_result = fix_result[0][1]
                                # print(recognition_result)
                            else:
                                recognition_result = recognition_result

                            f.write('{},{},{},{},{},{},{},{},{}\r\n'.format(
                                box[0, 0], box[0, 1], box[1, 0], box[1, 1],
                                box[2, 0], box[2, 1], box[3, 0], box[3, 1],
                                recognition_result))

                            # Draw bounding box
                            cv2.polylines(im, [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 255, 0), thickness=1)
                            # Draw recognition results area
                            text_area = box.copy()
                            text_area[2, 1] = text_area[1, 1]
                            text_area[3, 1] = text_area[0, 1]
                            text_area[0, 1] = text_area[0, 1] - 15
                            text_area[1, 1] = text_area[1, 1] - 15
                            cv2.fillPoly(im, [text_area.astype(np.int32).reshape((-1, 1, 2))], color=(255, 255, 0))
                            im_txt = cv2.putText(im, recognition_result, (box[0, 0], box[0, 1]), font, 0.5, (0, 0, 255), 1)
                            # Add Chinese text:
                            # im_txt = cv2ImgAddText(im, recognition_result, box[0, 0], box[0, 1], (0, 0, 149), 20)
                else:
                    # res_file = os.path.join(FLAGS.output_dir, 'res_' + '{}.txt'.format(os.path.basename(im_fn).split('.')[0]))
                    res_file = os.path.join(FLAGS.output_dir, '{}.txt'.format(os.path.basename(im_fn)))
                    f = open(res_file, "w")
                    im_txt = None
                    f.close()

                print('{} : detect {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms, recog {:.0f}ms'.format(
                    im_fn, timer['detect'] * 1000, timer['restore'] * 1000,
                    timer['nms'] * 1000, timer['recog'] * 1000))

                duration = time.time() - start_time
                print('[timing] {}'.format(duration))

                if not FLAGS.no_write_images:
                    img_path = os.path.join(FLAGS.output_dir, os.path.basename(im_fn))
                    # cv2.imwrite(img_path, im[:, :, ::-1])
                    if im_txt is not None:
                        cv2.imwrite(img_path, im_txt)
import tensorflow as tf

from helper import deprecated
from bktree import BKTree

NUM_SHIRTS = 69
SHIRT_MESSAGE = "Gimme shirt pls"
TREE = BKTree()


@deprecated
def brute_force_shirt(num_shirts=NUM_SHIRTS, shirt_message=SHIRT_MESSAGE):
    shirt_requests = []
    for i in range(num_shirts):
        shirt_requests.append(shirt_message)
    return shirt_requests


def ml_shirt_acquiry():
    shirt = tf.constant(SHIRT_MESSAGE)
    sess = tf.Session()
    return sess.run(shirt)


def spell_check(word):
    # loads the system word list; the result is not used further in this snippet
    with open('/usr/share/dict/words') as f:
        possible_words = [w.replace('\n', '') for w in f.readlines()]


def word_path(word):
    return TREE.find_words(word)
import timeit

from bktree import BKTree

business_dictionary = [a.strip() for a in open('business-names.txt')]
tree = BKTree(sanitize=True)
tree.add(business_dictionary)

setup = """
from bktree import BKTree
business_dictionary = [a.strip() for a in open('business-names.txt')]
tree = BKTree(sanitize=True)
tree.add(business_dictionary)
"""


def test_word(word, radius):
    perf = timeit.timeit(f'tree.search("{word}", {radius})', number=100, setup=setup)
    print(f'Performance of tree.search("{word}", {radius}) = {perf}')
    print(tree.search(f"{word}", 1))


if __name__ == "__main__":
    for w, r in [
        ('walmart', 1),
        ('walmartt', 1),
        ('walmarttt', 2),
        ('walllrt', 2),
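# The setup string above repeats the module-level construction because timeit.timeit
# evaluates its statement in a fresh namespace, so the tree must be rebuilt there.
# Since Python 3.5, passing globals=globals() lets the timed statement reuse the tree
# already built at import time. A small alternative sketch (the function name is ours,
# not part of the original benchmark):
def test_word_globals(word, radius):
    stmt = f'tree.search("{word}", {radius})'
    # reuse the module-level `tree` instead of re-reading business-names.txt in setup
    perf = timeit.timeit(stmt, number=100, globals=globals())
    print(f'Performance of {stmt} = {perf}')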
for j in range(len(graphs[i]['edges'])):
    str1 = ""
    for vertex in sorted(graphs[i]['edges'][j]):
        str1 += "v" + str(vertex)
    features[i][str1] = qualityEdges[i][j]

values = [build_by_features(features[i]) for i in range(l)]

valuesRevDict = {}
for i in range(len(values)):
    if values[i] in valuesRevDict:
        valuesRevDict[values[i]].append(i)
    else:
        valuesRevDict[values[i]] = [i]

tree = BKTree()
for value in values:
    tree.add(value)

# for i in range(l):
#     for j in range(i+1, l):
#         score = computeScore(values[i], values[j])
#         print str(graphs[i]["label"]) + " " + str(graphs[j]["label"]) + " " + str(score)

for i in range(l):
    closest_pairs = tree.find(values[i], MAX_DISTANCE)
    final_pairs = []
    for pair in closest_pairs:
        a, b = pair
        a = 1 - float(a) / F
alt_label = [
    Label(root, textvar=alt_str[0], justify='left'),
    Label(root, textvar=alt_str[1], justify='left'),
    Label(root, textvar=alt_str[2], justify='left'),
    Label(root, textvar=alt_str[3], justify='left'),
    Label(root, textvar=alt_str[4], justify='left')
]
for a_l in alt_label:
    a_l.pack(side="left")

root.config(menu=menubar)

# Finally, the application main loop
root.bind("<space>", word_proc)
texto.bind("<Button-1>", focused)
texto.bind("<Left>", focused)
texto.bind("<Right>", focused)
texto.bind("<Up>", focused)
texto.bind("<Down>", focused)
alt_label[0].bind("<Button-1>", swap_alt0)
alt_label[1].bind("<Button-1>", swap_alt1)
alt_label[2].bind("<Button-1>", swap_alt2)
alt_label[3].bind("<Button-1>", swap_alt3)
alt_label[4].bind("<Button-1>", swap_alt4)

BKT = BKTree(levenshtein, dict_words('aymara'))

root.mainloop()
def main(argv=None):
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    try:
        os.makedirs(FLAGS.output_dir)
    except OSError as e:
        if e.errno != 17:
            raise

    if FLAGS.use_vacab and os.path.exists("./vocab.txt"):
        bk_tree = BKTree(levenshtein, dict_words('./vocab.txt'))

    with tf.get_default_graph().as_default():
        input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
        input_transform_matrix = tf.placeholder(tf.float32, shape=[None, 6], name='input_transform_matrix')
        # input_box_mask = tf.placeholder(tf.int32, shape=[None], name='input_box_mask')
        input_box_mask = []
        input_box_mask.append(tf.placeholder(tf.int32, shape=[None], name='input_box_masks_0'))
        input_box_widths = tf.placeholder(tf.int32, shape=[None], name='input_box_widths')
        # input_box_nums = tf.placeholder(tf.int32, name='input_box_nums')
        # input_seq_len = tf.placeholder(tf.int32, shape=[None], name='input_seq_len')
        input_seq_len = input_box_widths[tf.argmax(input_box_widths, 0)] * tf.ones_like(input_box_widths)
        global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)

        shared_feature, f_score, f_geometry = detect_part.model(input_images)
        pad_rois = roi_rotate_part.roi_rotate_tensor_pad(shared_feature, input_transform_matrix, input_box_mask, input_box_widths)
        recognition_logits = recognize_part.build_graph(pad_rois, input_box_widths)
        _, dense_decode = recognize_part.decode(recognition_logits, input_box_widths)

        variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            im_fn_list = get_images()
            for im_fn in im_fn_list:
                im = cv2.imread(im_fn)[:, :, ::-1]
                start_time = time.time()
                im_resized, (ratio_h, ratio_w) = resize_image(im)
                # im_resized_d, (ratio_h_d, ratio_w_d) = resize_image_detection(im)
                timer = {'net': 0, 'restore': 0, 'nms': 0}
                start = time.time()
                score, geometry = sess.run([f_score, f_geometry], feed_dict={input_images: [im_resized]})

                boxes, timer = detect(score_map=score, geo_map=geometry, timer=timer)
                """
                if boxes is not None:
                    boxes = boxes[:, :8].reshape((-1, 4, 2))
                    boxes[:, :, 0] /= ratio_w
                    boxes[:, :, 1] /= ratio_h
                """
                # save to file
                if boxes is not None and boxes.shape[0] != 0:
                    res_file = os.path.join(
                        FLAGS.output_dir,
                        'res_' + '{}.txt'.format(os.path.basename(im_fn).split('.')[0]))

                    input_roi_boxes = boxes[:, :8].reshape(-1, 8)
                    # input_roi_boxes = boxes[:, :8].reshape((-1, 4, 2))
                    # input_roi_boxes = boxes.copy()
                    # input_roi_boxes[:, :, 0] *= ratio_w
                    # input_roi_boxes[:, :, 1] *= ratio_h
                    # input_roi_boxes = input_roi_boxes.reshape((-1, 8))

                    # boxes_masks = np.array([0] * input_roi_boxes.shape[0])
                    boxes_masks = [0] * input_roi_boxes.shape[0]
                    transform_matrixes, box_widths = get_project_matrix_and_width(input_roi_boxes)
                    # max_box_widths = max_width * np.ones(boxes_masks.shape[0])  # seq_len

                    # Run end to end
                    recog_decode = sess.run(dense_decode, feed_dict={
                        input_images: [im_resized],
                        input_transform_matrix: transform_matrixes,
                        input_box_mask[0]: boxes_masks,
                        input_box_widths: box_widths})
                    timer['net'] = time.time() - start

                    # Preparing to draw boxes
                    boxes = boxes[:, :8].reshape((-1, 4, 2))
                    boxes[:, :, 0] /= ratio_w
                    boxes[:, :, 1] /= ratio_h

                    # print "recognition result: "
                    # for pred in recog_decode:
                    #     print ground_truth_to_word(pred)

                    if recog_decode.shape[0] != boxes.shape[0]:
                        print "detection and recognition result are not equal!"
                        exit(-1)

                    with open(res_file, 'w') as f:
                        for i, box in enumerate(boxes):
                            # to avoid submitting errors
                            box = sort_poly(box.astype(np.int32))
                            if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
                                continue
                            recognition_result = ground_truth_to_word(recog_decode[i])

                            if FLAGS.use_vacab:
                                fix_result = bktree_search(bk_tree, recognition_result.upper())
                                if len(fix_result) != 0:
                                    recognition_result = fix_result[0][1]

                            f.write('{},{},{},{},{},{},{},{},{}\r\n'.format(
                                box[0, 0], box[0, 1], box[1, 0], box[1, 1],
                                box[2, 0], box[2, 1], box[3, 0], box[3, 1],
                                recognition_result))
                            """
                            f.write('{},{},{},{},{},{},{},{}\r\n'.format(
                                box[0, 0], box[0, 1], box[1, 0], box[1, 1],
                                box[2, 0], box[2, 1], box[3, 0], box[3, 1]))
                            """

                            # Draw bounding box
                            cv2.polylines(im[:, :, ::-1], [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 255, 0), thickness=1)
                            # Draw recognition results area
                            text_area = box.copy()
                            text_area[2, 1] = text_area[1, 1]
                            text_area[3, 1] = text_area[0, 1]
                            text_area[0, 1] = text_area[0, 1] - 15
                            text_area[1, 1] = text_area[1, 1] - 15
                            cv2.fillPoly(im[:, :, ::-1], [text_area.astype(np.int32).reshape((-1, 1, 2))], color=(255, 255, 0))
                            im_txt = cv2.putText(im[:, :, ::-1], recognition_result, (box[0, 0], box[0, 1]), font, 0.5, (0, 0, 255), 1)
                else:
                    timer['net'] = time.time() - start
                    res_file = os.path.join(
                        FLAGS.output_dir,
                        'res_' + '{}.txt'.format(os.path.basename(im_fn).split('.')[0]))
                    f = open(res_file, "w")
                    im_txt = None
                    f.close()

                print('{} : net {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms'.format(
                    im_fn, timer['net'] * 1000, timer['restore'] * 1000, timer['nms'] * 1000))

                duration = time.time() - start_time
                print('[timing] {}'.format(duration))

                if not FLAGS.no_write_images:
                    img_path = os.path.join(FLAGS.output_dir, os.path.basename(im_fn))
                    # cv2.imwrite(img_path, im[:, :, ::-1])
                    if im_txt is not None:
                        cv2.imwrite(img_path, im_txt)
def main(argv=None):
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    try:
        os.makedirs(FLAGS.output_dir)
    except OSError as e:
        if e.errno != 17:
            raise

    bk_tree = BKTree(levenshtein, list_words(FLAGS.vocab))
    # bk_tree = bktree.Tree()

    with tf.get_default_graph().as_default():
        input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
        input_feature_map = tf.placeholder(tf.float32, shape=[None, None, None, 32], name='input_feature_map')
        input_transform_matrix = tf.placeholder(tf.float32, shape=[None, 6], name='input_transform_matrix')
        input_box_mask = []
        input_box_mask.append(tf.placeholder(tf.int32, shape=[None], name='input_box_masks_0'))
        input_box_widths = tf.placeholder(tf.int32, shape=[None], name='input_box_widths')

        input_seq_len = input_box_widths[tf.argmax(input_box_widths, 0)] * tf.ones_like(input_box_widths)
        global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)

        shared_feature, f_score, f_geometry = detect_part.model(input_images)
        pad_rois = roi_rotate_part.roi_rotate_tensor_pad(input_feature_map, input_transform_matrix, input_box_mask, input_box_widths)
        recognition_logits = recognize_part.build_graph(pad_rois, input_box_widths)
        _, dense_decode = recognize_part.decode(recognition_logits, input_box_widths)

        variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            # im_fn_list = get_images()
            if FLAGS.just_infer:
                im_fn_list, _, _ = get_image_self("/data/ceph_11015/ssd/anhan/nba/video2image")
            else:
                im_fn_list, corridate_list, label_list = get_image_self("/data/ceph_11015/ssd/anhan/nba/video2image")

            wrong = 0
            total = 0
            for ind, im_fn in enumerate(im_fn_list):
                # print("im_fn:", im_fn)
                im = cv2.imread(im_fn)[:, :, ::-1]
                im = cv2.resize(im, (960, 540))
                start_time = time.time()
                im_resized, (ratio_h, ratio_w) = resize_image(im)
                # im_resized_d, (ratio_h_d, ratio_w_d) = resize_image_detection(im)

                timer = {'detect': 0, 'restore': 0, 'nms': 0, 'recog': 0}
                start = time.time()
                shared_feature_map, score, geometry = sess.run(
                    [shared_feature, f_score, f_geometry],
                    feed_dict={input_images: [im_resized]})

                boxes, timer = detect(score_map=score, geo_map=geometry, timer=timer)
                timer['detect'] = time.time() - start

                start = time.time()  # reset for recognition
                res = None
                str_list = []
                if boxes is not None and boxes.shape[0] != 0:
                    # res_file_path = os.path.join(FLAGS.output_dir, 'res_' + '{}.txt'.format(os.path.basename(im_fn).split('.')[0]))
                    # res_file_path = os.path.join(FLAGS.output_dir, '{}.txt'.format(os.path.basename(im_fn)))
                    input_roi_boxes = boxes[:, :8].reshape(-1, 8)
                    recog_decode_list = []
                    # Here avoid too many text areas leading to OOM
                    for batch_index in range(input_roi_boxes.shape[0] // 32 + 1):  # test roi batch size is 32
                        start_slice_index = batch_index * 32
                        end_slice_index = (batch_index + 1) * 32 if input_roi_boxes.shape[0] >= (batch_index + 1) * 32 else input_roi_boxes.shape[0]
                        tmp_roi_boxes = input_roi_boxes[start_slice_index:end_slice_index]
                        boxes_masks = [0] * tmp_roi_boxes.shape[0]
                        transform_matrixes, box_widths = get_project_matrix_and_width(tmp_roi_boxes)
                        # max_box_widths = max_width * np.ones(boxes_masks.shape[0])  # seq_len

                        # Run end to end
                        recog_decode = sess.run(dense_decode, feed_dict={
                            input_feature_map: shared_feature_map,
                            input_transform_matrix: transform_matrixes,
                            input_box_mask[0]: boxes_masks,
                            input_box_widths: box_widths})
                        recog_decode_list.extend([r for r in recog_decode])

                    timer['recog'] = time.time() - start
                    # Preparing to draw boxes
                    boxes = boxes[:, :8].reshape((-1, 4, 2))
                    boxes[:, :, 0] /= ratio_w
                    boxes[:, :, 1] /= ratio_h

                    if len(recog_decode_list) != boxes.shape[0]:
                        print("detection and recognition result are not equal!")
                        exit(-1)

                    scores = {}
                    score_index = 0
                    time_left = {}
                    time_index = 0
                    team_name = {}
                    quarter_dict = {}
                    remainder_attack_time = {}
                    remainder_attack_time_index = 0
                    recognition_result_num = 0
                    points = {}
                    for i, box in enumerate(boxes):
                        # to avoid submitting errors
                        box = sort_poly(box.astype(np.int32))
                        if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
                            continue
                        recognition_result = ground_truth_to_word(recog_decode_list[i])

                        if contain_eng(recognition_result):
                            # print(recognition_result)
                            fix_result = bktree_search(bk_tree, recognition_result.lower())
                            # print(fix_result)
                            if len(fix_result) != 0:
                                recognition_result = fix_result[0][1]
                            # print(recognition_result)
                        else:
                            recognition_result = recognition_result

                        if recognition_result in all_team:
                            team_name[recognition_result] = [
                                (int(box[0, 0]) + int(box[2, 0])) / 2,
                                (int(box[0, 1]) + int(box[2, 1])) / 2]
                            points[recognition_result] = [box[0, 0], box[2, 0]]
                        if recognition_result in quarter:
                            quarter_dict[recognition_result] = [
                                (int(box[0, 0]) + int(box[2, 0])) / 2,
                                (int(box[0, 1]) + int(box[2, 1])) / 2]
                            points[recognition_result] = [box[0, 0], box[2, 0]]
                        if recognition_result.isdigit():
                            scores[recognition_result + "_" + str(score_index)] = [
                                (int(box[0, 0]) + int(box[2, 0])) / 2,
                                (int(box[0, 1]) + int(box[2, 1])) / 2]
                            points[recognition_result + "_" + str(score_index)] = [box[0, 0], box[2, 0]]
                            score_index += 1
                        if ":" in recognition_result:
                            time_left[recognition_result + "_" + str(time_index)] = [
                                (int(box[0, 0]) + int(box[2, 0])) / 2,
                                (int(box[0, 1]) + int(box[2, 1])) / 2]
                            points[recognition_result + "_" + str(time_index)] = [box[0, 0], box[2, 0]]
                            time_index += 1
                        if "." in recognition_result and ":" not in recognition_result:
                            remainder_attack_time[recognition_result + "_" + str(remainder_attack_time_index)] = [
                                (int(box[0, 0]) + int(box[2, 0])) / 2,
                                (int(box[0, 1]) + int(box[2, 1])) / 2]
                            points[recognition_result + "_" + str(remainder_attack_time_index)] = [box[0, 0], box[2, 0]]
                            remainder_attack_time_index += 1

                        recognition_result_num += 1
                        str_list.append(recognition_result)

                        # Draw bounding box
                        # cv2.polylines(im, [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 255, 0), thickness=1)
                        # Draw recognition results area
                        # text_area = box.copy()
                        # text_area[2, 1] = text_area[1, 1]
                        # text_area[3, 1] = text_area[0, 1]
                        # text_area[0, 1] = text_area[0, 1] - 15
                        # text_area[1, 1] = text_area[1, 1] - 15
                        # cv2.fillPoly(im, [text_area.astype(np.int32).reshape((-1, 1, 2))], color=(255, 255, 0))
                        # im_txt = cv2.putText(im, recognition_result, (box[0, 0], box[0, 1]), font, 0.5, (0, 0, 255), 1)
                        # Add Chinese text:
                        # im_txt = cv2ImgAddText(im, recognition_result, box[0, 0], box[0, 1], (0, 0, 149), 20)

                    if recognition_result_num == 7 or recognition_result_num == 6 or recognition_result_num == 5 or recognition_result_num == 8:
                        res = get_content(remainder_attack_time, time_left, team_name, scores, quarter_dict)
                    elif recognition_result_num == 9:
                        sort_points = sorted(points.items(), key=lambda item: item[1][0])
                        x_coordiate = []
                        for pair in sort_points:
                            x_coordiate.append(pair[1][0])
                            x_coordiate.append(pair[1][1])
                        x_sort = sorted(x_coordiate)
                        if x_sort == x_coordiate:
                            drop1 = sort_points[1][0]
                            drop2 = sort_points[4][0]
                            if drop1 in remainder_attack_time:
                                remainder_attack_time = remove_key(remainder_attack_time, drop1)
                            if drop2 in remainder_attack_time:
                                remainder_attack_time = remove_key(remainder_attack_time, drop2)
                            if drop1 in time_left:
                                time_left = remove_key(time_left, drop1)
                            if drop2 in time_left:
                                time_left = remove_key(time_left, drop2)
                            if drop1 in scores:
                                scores = remove_key(scores, drop1)
                            if drop2 in scores:
                                scores = remove_key(scores, drop2)
                            res = get_content(remainder_attack_time, time_left, team_name, scores, quarter_dict)

                if not FLAGS.just_infer:
                    corridate_true = corridate_list[ind].split("_")[4:]
                    label_true = label_list[ind].split("_")
                    res_true = get_score_info_v2(corridate_true, label_true)
                    if res != res_true:
                        # print(im_fn.split("/")[-1], 'wrong!!!')
                        wrong += 1
                        # print(im_fn.split("/")[-1], label_list[ind], res_true, res, ("_").join(str_list))
                    total += 1
                    print(im_fn.split("/")[-1], label_list[ind], res_true, res, ("_").join(str_list))
                else:
                    print(im_fn.split("/")[-1], res, ("_").join(str_list))

                duration = time.time() - start_time
                # print('{} : detect {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms, recog {:.0f}ms'.format(im_fn, timer['detect']*1000, timer['restore']*1000, timer['nms']*1000, timer['recog']*1000))

            print("wrong:{}".format(wrong))
            print("total:{}".format(total))
            print("precision:{}".format((total - wrong) / total))