Example No. 1
def test_bktree():
    test_passed = True

    with open('testdatabase.p', 'rb') as f:
        test_database = pickle.load(f)
    test_database = test_database['bktree']

    keys = test_database['keys']
    result_target = test_database['result']
    query = test_database['query']

    BKT = BKTree(levenshtein_distance_DP)
    for key in keys:
        BKT.insert(key)

    result = BKT.get(query)

    for (word,distance) in result:
        encounter = 0
        for (word_target,distance_target) in result_target:
            if word == word_target:
                if distance != distance_target:
                    test_passed = False
                encounter +=1
        if encounter != 1:
            test_passed = False

    return test_passed
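
These snippets come from different repositories, so the BKTree and levenshtein_distance_DP they import are never shown on this page. For orientation, here is a minimal sketch compatible with the calls above (BKTree(metric), insert(), and get() returning (word, distance) pairs); the node layout and the default search radius are assumptions, not the original code.

def levenshtein_distance_DP(a, b):
    """Classic dynamic-programming edit distance."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                # deletion
                           cur[j - 1] + 1,             # insertion
                           prev[j - 1] + (ca != cb)))  # substitution
        prev = cur
    return prev[-1]

class BKTree:
    def __init__(self, metric):
        self.metric = metric
        self.root = None  # node = (word, {edge_distance: child_node})

    def insert(self, word):
        if self.root is None:
            self.root = (word, {})
            return
        node = self.root
        while True:
            d = self.metric(word, node[0])
            if d in node[1]:
                node = node[1][d]
            else:
                node[1][d] = (word, {})
                return

    def get(self, query, tolerance=2):
        """All (word, distance) pairs within `tolerance` of `query`."""
        results = []
        stack = [self.root] if self.root else []
        while stack:
            word, children = stack.pop()
            d = self.metric(query, word)
            if d <= tolerance:
                results.append((word, d))
            # triangle inequality: only subtrees whose edge distance lies in
            # [d - tolerance, d + tolerance] can contain matches
            stack.extend(child for k, child in children.items()
                         if d - tolerance <= k <= d + tolerance)
        return results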
def test_bk_nearest_neighbor_search():
    """
    Test BK_Nearest_Neighbor_Search function.
    """
    # --- Preparations

    string_list = [
        'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine',
        'ten'
    ]
    tree = BKTree(string_list)

    # --- Exercise functionality
    search1 = bk_nearest_neighbor_search('eight', tree)
    search2 = bk_nearest_neighbor_search('ter', tree)
    search3 = bk_nearest_neighbor_search('123456789', tree)
    search4 = bk_nearest_neighbor_search('ffff', tree)

    # --- Check results
    assert len(search1) == 2
    assert search1[0] == 0
    assert search1[1:] == ['eight']
    assert len(search2) == 2
    assert search2[0] == 1
    assert search2[1:] == ['ten']
    assert len(search3) == 11
    assert search3[0] == 9
    for value in search3[1:]:
        assert value in string_list
    assert len(search4) == 3
    assert search4[0] == 3
    for value in search4[1:]:
        assert value in ['five', 'four']
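
A sketch of the function under test, inferred from the assertions above: it returns [best_distance, word1, word2, ...] with every tree word tied at the minimal distance from the query. It reuses the (word, children) node layout of the sketch tree shown earlier and assumes the tree stores its metric as tree.metric; both details are assumptions about this repo's BKTree, whose constructor takes a word list.

def bk_nearest_neighbor_search(query, tree):
    best_dist, best_words = None, []
    stack = [tree.root] if tree.root else []
    while stack:
        word, children = stack.pop()
        d = tree.metric(query, word)
        if best_dist is None or d < best_dist:
            best_dist, best_words = d, [word]
        elif d == best_dist:
            best_words.append(word)
        # triangle inequality: a subtree at edge distance k holds only words
        # at distance >= |k - d| from the query, so prune beyond the best
        stack.extend(child for k, child in children.items()
                     if abs(k - d) <= best_dist)
    return [best_dist] + best_words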
Example No. 5
def main():
    
    og_ids = [
        "1", 
        "2", 
        "18", # Cellosaurus with relavent terms for human biology 
        "5", 
        "7", 
        "9",
        "19"
    ]
    ogs = [load_ontology.load(x)[0] for x in og_ids]
    str_to_terms = defaultdict(list)

    print("Gathering all term string identifiers in ontologies...")
    string_identifiers = set()
    for og in ogs:
        for id, term in og.id_to_term.items():
            str_to_terms[term.name].append([term.id, "TERM_NAME"])
            string_identifiers.add(term.name)
            for syn in term.synonyms:
                str_to_terms[syn.syn_str].append([term.id, "SYNONYM_%s" % syn.syn_type])
                string_identifiers.add(syn.syn_str)

    print("Building the BK-Tree...")
    bk_tree = BKTree(string_metrics.bag_dist_multiset, string_identifiers)

    # with open("fuzzy_match_bk_tree.pickle", "w") as f:
    with open("fuzzy_match_bk_tree.pickle", "wb") as f:
        pickle.dump(bk_tree, f)

    with open("fuzzy_match_string_data.json", "w") as f:
        f.write(json.dumps(str_to_terms, indent=4, separators=(',', ': ')))
Example No. 6
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="Build BK-tree for fuzzy matching.\n"
        "Example of ID specification:\n\thuman: 1,2,18,5,7,9,19\n"
        "\tarabidopsis: 20,21")
    parser.add_argument("-i",
                        "--ids",
                        required=True,
                        help="comma separated list of ontology IDs.")
    parser.add_argument(
        "-j",
        "--json_filename",
        required=True,
        help="filename of output json."
        "'fuzzy_match_string_data_SPECIES.json' is recommended.")
    parser.add_argument("-p",
                        "--pickle_filename",
                        required=True,
                        help="filename of output pickle"
                        "'fuzzy_match_bk_tree_SPECIES.pickle' is recommended.")
    args = parser.parse_args()

    og_ids = args.ids.split(",")
    ogs = [load_ontology.load(x)[0] for x in og_ids]
    str_to_terms = defaultdict(list)

    print("Gathering all term string identifiers in ontologies...")
    string_identifiers = set()
    for og in ogs:
        for id, term in og.id_to_term.items():
            str_to_terms[term.name].append([term.id, "TERM_NAME"])
            string_identifiers.add(term.name)
            for syn in term.synonyms:
                str_to_terms[syn.syn_str].append(
                    [term.id, "SYNONYM_%s" % syn.syn_type])
                string_identifiers.add(syn.syn_str)

    print("Building the BK-Tree...")
    bk_tree = BKTree(string_metrics.bag_dist_multiset, string_identifiers)

    with open(args.pickle_filename, "wb") as f:
        pickle.dump(bk_tree, f)

    with open(args.json_filename, "w") as f:
        f.write(json.dumps(str_to_terms, indent=4, separators=(',', ': ')))
Example No. 7
    def __init__(self, proxy_map):
        super(SpecificWorker, self).__init__(proxy_map)
        self.timer.timeout.connect(self.compute)
        self.Period = 200
        self.timer.start(self.Period)

        # load the pre-trained EAST text detector
        print "[INFO] loading EAST text detector..."
        self.net = cv2.dnn.readNet(NET_FILE)
        self.model = crnn.CRNN(32, 1, 37, 256)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        print "[INFO] loading CRNN text recognizer..."
        self.model.load_state_dict(torch.load(MODEL_FILE))
        self.tree = None
        if self.use_lexicon:
            print "[INFO] loading generic english lexicon..."
            lexicon = []
            with open(LEXICON_FILE) as f:
                for line in f.read().splitlines():
                    lexicon.append(line.lower())
            print "Length of the lexicon: ", len(lexicon)
            self.tree = BKTree(hamming_distance, lexicon)
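
The lexicon tree above is keyed by hamming_distance, which is only defined positionwise. A common workaround for variable-length OCR output, assumed here rather than confirmed by the repo, is to count the length difference as extra mismatches:

def hamming_distance(a, b):
    # positionwise mismatches plus the length difference; this padding rule
    # is an assumption, the repo's function may reject unequal lengths
    return sum(x != y for x, y in zip(a, b)) + abs(len(a) - len(b))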
Example No. 8
    def __init__(
            self,
            filepath_corpus,
            filepath_bk_tree,
            filepath_token_cnts,
            vocab_path,
            ngram_range=(2, 3),
            token_pattern=r"(?u)\b\w\w+\b",
            preprocessor=None,
            tokenizer=None,
            lowercase=True,
            stop_words=None  # list of stopwords; keep as None
    ):
        self.filepath_corpus = filepath_corpus
        self.filepath_bk_tree = filepath_bk_tree
        self.vocab_path = vocab_path
        self.ngram_range = ngram_range
        self.stop_words = stop_words
        self.lowercase = lowercase
        self.preprocessor = preprocessor
        self.tokenizer = tokenizer
        self.token_pattern = token_pattern

        with open(vocab_path, 'rb') as f:
            word2idx, idx2word = pickle.load(f)

        self.vocab_size = len(word2idx)

        try:
            with open(self.filepath_bk_tree, 'rb') as f:
                self.tree = pickle.load(f)

        except Exception:
            # no tree on disk yet: build it from the vocabulary and persist it
            pickle_dump(BKTree(items_dict=word2idx), self.filepath_bk_tree)

            with open(self.filepath_bk_tree, 'rb') as f:
                self.tree = pickle.load(f)

        self.token_cnts = None
        self.bi_pre_token_cnts = None
        self.bi_post_token_cnts = None
        self.tri_token_cnts = None

        self.filepath_token_cnts = None
        self.filepath_bi_pre_token_cnts = None
        self.filepath_bi_post_token_cnts = None
        self.filepath_tri_token_cnts = None

        if filepath_token_cnts is not None and os.path.exists(
                filepath_token_cnts):
            print('LOADING...')
            with io.open(filepath_token_cnts, encoding='utf-8') as f:
                self.token_cnts = json.load(f)

            with io.open('bi_pre_' + filepath_token_cnts,
                         encoding='utf-8') as f:
                self.bi_pre_token_cnts = json.load(f)

            with io.open('bi_post_' + filepath_token_cnts,
                         encoding='utf-8') as f:
                self.bi_post_token_cnts = json.load(f)

            with io.open('tri_' + filepath_token_cnts, encoding='utf-8') as f:
                self.tri_token_cnts = json.load(f)

        else:
            print('GETTING FROM CORPUS...')
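            # NOTE: defaultdict() without a factory behaves like a plain
            # dict and still raises KeyError on missing tokens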
            self.token_cnts = defaultdict()
            self.bi_pre_token_cnts = defaultdict()
            self.bi_post_token_cnts = defaultdict()
            self.tri_token_cnts = defaultdict()

            self.filepath_token_cnts = filepath_token_cnts
            self.filepath_bi_pre_token_cnts = 'bi_pre_' + filepath_token_cnts
            self.filepath_bi_post_token_cnts = 'bi_post_' + filepath_token_cnts
            self.filepath_tri_token_cnts = 'tri_' + filepath_token_cnts

            self.build_ngrams_count_from_corpus()
Example No. 9
    print("Constructing dataset........", end='', flush=True)
    dataset = Dataset(csv_file, query_processor)
    dataset.load()
    print("Dataset Loaded!")

    #process it into an efficient binary search tree
    print("Constructing RB-Tree...", end='', flush=True)
    binary_search_tree = RedBlackTree()
    temp_dataset = dataset.get_token_bookids()
    for entry in temp_dataset:
        binary_search_tree.insert(entry[0], entry[1])
    print("RB-Tree constructed!")

    #create a BK-tree
    print("Constructing BK-Tree...", end='', flush=True)
    bk_tree = BKTree(levenshtein_distance_DP)
    english_dictionary_list = dataset.get_dictionary()
    for word in english_dictionary_list:
        bk_tree.insert(word)
    print("BK-Tree constructed!")

    if CONSTRUCT_DEBUG_MODE:
        pickle.dump(dataset, open("dataset.p", 'wb'))
        pickle.dump(binary_search_tree, open("binary_search_tree.p", 'wb'))
        pickle.dump(bk_tree, open("bk_tree.p", 'wb'))

    # create a Ranker
    ranker = Ranker(dataset, query_processor, binary_search_tree, bk_tree)

    # Get the book titles for "The search engine."
    ranker.evaluate("The search engine.")
Example No. 10
class PriorGenerator:
    def __init__(self, zero_to_alpha):
        #self.freq = {}
        #self.successor = {}
        self.word_successor = {}
        self.dictionary = set()
        self.stopSymbols = []
        self.zero_to_alpha = zero_to_alpha
        self.bk_tree = BKTree()

    def get_word_successor(self, w1, w2):
        try:
            return self.word_successor[w1][w2]
        except KeyError:
            return self.zero_to_alpha

    def load_stop_symbols_from_file(self, path):
        with open(path) as file:
            for x in file.readlines():
                self.stopSymbols.append(x.rstrip("\n"))

    def remove_stop_symbols(self, line):
        for x in self.stopSymbols:
            line = line.replace(x, " ")
        return line

    def analize_freq(self, path):
        with open(path) as file:
            for line in file.readlines():
                line = line.rstrip("\n")
                line = self.remove_stop_symbols(line)
                # English letters only?
                line = line.lower()

                # Analyze character frequencies
                # for i in range(len(line)):
                #     char = line[i]
                #     if (char > "z" or char < "a"):
                #         continue
                #     try:
                #         self.freq[char] += 1
                #     except:
                #         self.freq[char] = 1
                #     if i < len(line)-1:
                #         try:
                #             suc = self.successor[char]
                #             try:
                #                 suc[line[i+1]] += 1
                #             except:
                #                 suc[line[i+1]] = 1
                #         except:
                #             self.successor[char] = {}
                #             self.successor[char][line[i+1]] = 1
                # If the text is not English-only, apostrophes and similar characters must also be handled

                # Analyze the words
                line = line.split()
                for word_id in range(len(line)):
                    self.dictionary.add(line[word_id])
                    if word_id < len(line) - 1:
                        try:
                            suc = self.word_successor[line[word_id]]
                            try:
                                suc[line[word_id + 1]] += 1
                            except KeyError:
                                suc[line[word_id + 1]] = 1
                        except KeyError:
                            self.word_successor[line[word_id]] = {}
                            self.word_successor[line[word_id]][line[word_id + 1]] = 1

    def finalize(self):
        #tot = sum(self.freq.values())
        # for key in self.freq.keys():
        #     self.freq[key] /= tot
        #     tot2 = sum(self.successor[key].values())
        #     for key2 in self.successor[key].keys():
        #         self.successor[key][key2] /= tot2
        for key in self.word_successor.keys():
            tot = sum(self.word_successor[key].values())
            for key2 in self.word_successor[key].keys():
                self.word_successor[key][key2] /= tot
        for key in self.dictionary:
            self.bk_tree.addWord(key)

    def serialize(self, path):
        tmp_dic = {
            "Dictionary": list(self.dictionary),
            "Word_successor": self.word_successor,
            "BK_Tree": self.bk_tree.root
        }
        with open(path, "w") as outFile:
            json.dump(tmp_dic, outFile)

    def deserialize(self, path):
        with open(path) as inFile:
            tmp_dic = json.load(inFile)
            self.dictionary = tmp_dic["Dictionary"]
            #self.freq = tmp_dic["Freq"]
            #self.successor = tmp_dic["Successor"]
            self.word_successor = tmp_dic["Word_successor"]
            self.bk_tree = BKTree()
            self.bk_tree.root = tmp_dic["BK_Tree"]
Example No. 11
def main(argv=None):
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    try:
        os.makedirs(FLAGS.output_dir)
    except OSError as e:
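        # errno 17 is EEXIST: the output directory already exists, carry on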
        if e.errno != 17:
            raise

    bk_tree = BKTree(levenshtein, list_words(FLAGS.vocab))
    # bk_tree = bktree.Tree()

    with tf.get_default_graph().as_default():
        input_images = tf.placeholder(tf.float32,
                                      shape=[None, None, None, 3],
                                      name='input_images')
        input_feature_map = tf.placeholder(tf.float32,
                                           shape=[None, None, None, 32],
                                           name='input_feature_map')
        input_transform_matrix = tf.placeholder(tf.float32,
                                                shape=[None, 6],
                                                name='input_transform_matrix')
        input_box_mask = []
        input_box_mask.append(
            tf.placeholder(tf.int32, shape=[None], name='input_box_masks_0'))
        input_box_widths = tf.placeholder(tf.int32,
                                          shape=[None],
                                          name='input_box_widths')

        input_seq_len = input_box_widths[tf.argmax(
            input_box_widths, 0)] * tf.ones_like(input_box_widths)
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        shared_feature, f_score, f_geometry = detect_part.model(input_images)
        pad_rois = roi_rotate_part.roi_rotate_tensor_pad(
            input_feature_map, input_transform_matrix, input_box_mask,
            input_box_widths)
        recognition_logits = recognize_part.build_graph(
            pad_rois, input_box_widths)
        _, dense_decode = recognize_part.decode(recognition_logits,
                                                input_box_widths)

        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(
                FLAGS.checkpoint_path,
                os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            im_fn_list = get_images()
            for im_fn in im_fn_list:
                #im = cv2.imread(im_fn)[:, :, ::-1]
                im = cv2.imread(im_fn)
                im = cv2.resize(im, (960, 540))

                start_time = time.time()
                im_resized, (ratio_h, ratio_w) = resize_image(im)
                # im_resized_d, (ratio_h_d, ratio_w_d) = resize_image_detection(im)

                timer = {'detect': 0, 'restore': 0, 'nms': 0, 'recog': 0}
                start = time.time()
                shared_feature_map, score, geometry = sess.run(
                    [shared_feature, f_score, f_geometry],
                    feed_dict={input_images: [im_resized]})

                boxes, timer = detect(score_map=score,
                                      geo_map=geometry,
                                      timer=timer)
                timer['detect'] = time.time() - start
                start = time.time()  # reset for recognition
                if boxes is not None and boxes.shape[0] != 0:
                    #res_file_path = os.path.join(FLAGS.output_dir,'res_' + '{}.txt'.format(os.path.basename(im_fn).split('.')[0]))
                    res_file_path = os.path.join(
                        FLAGS.output_dir,
                        '{}.txt'.format(os.path.basename(im_fn)))

                    input_roi_boxes = boxes[:, :8].reshape(-1, 8)
                    recog_decode_list = []
                    # Here avoid too many text area leading to OOM
                    for batch_index in range(input_roi_boxes.shape[0] // 32 +
                                             1):  # test roi batch size is 32
                        start_slice_index = batch_index * 32
                        end_slice_index = min(
                            (batch_index + 1) * 32, input_roi_boxes.shape[0])
                        tmp_roi_boxes = input_roi_boxes[
                            start_slice_index:end_slice_index]

                        boxes_masks = [0] * tmp_roi_boxes.shape[0]
                        transform_matrixes, box_widths = get_project_matrix_and_width(
                            tmp_roi_boxes)
                        #max_box_widths = max_width * np.ones(boxes_masks.shape[0]) # seq_len

                        # Run end to end
                        recog_decode = sess.run(dense_decode,
                                                feed_dict={
                                                    input_feature_map:
                                                    shared_feature_map,
                                                    input_transform_matrix:
                                                    transform_matrixes,
                                                    input_box_mask[0]:
                                                    boxes_masks,
                                                    input_box_widths:
                                                    box_widths
                                                })
                        recog_decode_list.extend([r for r in recog_decode])

                    timer['recog'] = time.time() - start
                    # Preparing for draw boxes
                    boxes = boxes[:, :8].reshape((-1, 4, 2))
                    boxes[:, :, 0] /= ratio_w
                    boxes[:, :, 1] /= ratio_h

                    if len(recog_decode_list) != boxes.shape[0]:
                        print("detection and recognition results are not equal!")
                        exit(-1)

                    with open(res_file_path, 'w') as f:
                        for i, box in enumerate(boxes):
                            # to avoid submitting errors
                            box = sort_poly(box.astype(np.int32))
                            if np.linalg.norm(box[0] -
                                              box[1]) < 5 or np.linalg.norm(
                                                  box[3] - box[0]) < 5:
                                continue
                            recognition_result = ground_truth_to_word(
                                recog_decode_list[i])

                            if contain_eng(recognition_result):
                                print(recognition_result)
                                fix_result = bktree_search(
                                    bk_tree, recognition_result.lower())
                                print(fix_result)
                                if len(fix_result) != 0:
                                    recognition_result = fix_result[0][1]
                                # print(recognition_result)
                            # non-English results are left as-is

                            f.write('{},{},{},{},{},{},{},{},{}\r\n'.format(
                                box[0, 0], box[0, 1], box[1, 0], box[1, 1],
                                box[2, 0], box[2, 1], box[3, 0], box[3, 1],
                                recognition_result))

                            # Draw bounding box
                            cv2.polylines(
                                im, [box.astype(np.int32).reshape((-1, 1, 2))],
                                True,
                                color=(255, 255, 0),
                                thickness=1)
                            # Draw recognition results area
                            text_area = box.copy()
                            text_area[2, 1] = text_area[1, 1]
                            text_area[3, 1] = text_area[0, 1]
                            text_area[0, 1] = text_area[0, 1] - 15
                            text_area[1, 1] = text_area[1, 1] - 15
                            cv2.fillPoly(im, [
                                text_area.astype(np.int32).reshape((-1, 1, 2))
                            ],
                                         color=(255, 255, 0))
                            im_txt = cv2.putText(im, recognition_result,
                                                 (box[0, 0], box[0, 1]), font,
                                                 0.5, (0, 0, 255), 1)
                            # To add Chinese text:
                            # im_txt = cv2ImgAddText(im, recognition_result, box[0, 0], box[0, 1], (0, 0, 149), 20)
                else:
                    #res_file = os.path.join(FLAGS.output_dir, 'res_' + '{}.txt'.format(os.path.basename(im_fn).split('.')[0]))
                    res_file = os.path.join(
                        FLAGS.output_dir,
                        '{}.txt'.format(os.path.basename(im_fn)))
                    f = open(res_file, "w")
                    im_txt = None
                    f.close()

                print(
                    '{} : detect {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms, recog {:.0f}ms'
                    .format(im_fn, timer['detect'] * 1000,
                            timer['restore'] * 1000, timer['nms'] * 1000,
                            timer['recog'] * 1000))

                duration = time.time() - start_time
                print('[timing] {}'.format(duration))

                if not FLAGS.no_write_images:
                    img_path = os.path.join(FLAGS.output_dir,
                                            os.path.basename(im_fn))
                    #cv2.imwrite(img_path, im[:, :, ::-1])
                    if im_txt is not None:
                        cv2.imwrite(img_path, im_txt)
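
This script leans on two helpers that are not shown: list_words, which loads FLAGS.vocab, and bktree_search, whose result is indexed as fix_result[0][1] above. Plausible sketches, with the search radius and the tree's find() method being assumptions:

def list_words(path):
    # one vocabulary entry per line
    with open(path) as f:
        return [line.strip().lower() for line in f if line.strip()]

def bktree_search(tree, word, radius=1):
    # assumes the tree exposes find(word, radius) returning (distance, word)
    # pairs sorted by distance, so fix_result[0][1] is the closest word
    return tree.find(word, radius)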
Example No. 12
import tensorflow as tf
from helper import deprecated
from bktree import BKTree

NUM_SHIRTS = 69
SHIRT_MESSAGE = "Gimme shirt pls"
TREE = BKTree()


@deprecated
def brute_force_shirt(num_shirts=NUM_SHIRTS, shirt_message=SHIRT_MESSAGE):
    shirt_requests = []
    for i in range(num_shirts):
        shirt_requests.append(shirt_message)

    return shirt_requests


def ml_shirt_acquiry():
    shirt = tf.constant(SHIRT_MESSAGE)
    sess = tf.Session()
    return sess.run(shirt)


def spell_check(word):
    with open('/usr/share/dict/words') as f:
        possible_words = [w.replace('\n', '') for w in f.readlines()]
    return word in possible_words  # report whether the word is in the system dictionary


def word_path(word):
    return TREE.find_words(word)
Example No. 13
import timeit

from bktree import BKTree

business_dictionary = [a.strip() for a in open('business-names.txt')]
tree = BKTree(sanitize=True)
tree.add(business_dictionary)

setup = """
from bktree import BKTree

business_dictionary = [a.strip() for a in open('business-names.txt')]
tree = BKTree(sanitize=True)
tree.add(business_dictionary)
"""


def test_word(word, radius):
    perf = timeit.timeit(f'tree.search("{word}", {radius})',
                         number=100,
                         setup=setup)
    print(f'Performance of tree.search("{word}", {radius}) = {perf}')
    print(tree.search(word, radius))


if __name__ == "__main__":
    for w, r in [
        ('walmart', 1),
        ('walmartt', 1),
        ('walmarttt', 2),
        ('walllrt', 2),
Example No. 14
        for j in range(len(graphs[i]['edges'])):
            str1 = ""
            for vertex in sorted(graphs[i]['edges'][j]):
                str1 += "v" + str(vertex)
            features[i][str1] = qualityEdges[i][j]

    values = [build_by_features(features[i]) for i in range(l)]
    valuesRevDict = {}

    for i in range(len(values)):
        if values[i] in valuesRevDict:
            valuesRevDict[values[i]].append(i)
        else:
            valuesRevDict[values[i]] = [i]

    tree = BKTree()

    for value in values:
        tree.add(value)

    # for i in range(l):
    #     for j in range(i+1,l):
    #         score = computeScore(values[i], values[j])
    #         print str(graphs[i]["label"]) + " " + str(graphs[j]["label"]) + " " + str(score)

    for i in range(l):
        closest_pairs = tree.find(values[i], MAX_DISTANCE)
        final_pairs = []
        for pair in closest_pairs:
            a, b = pair
            a = 1 - float(a) / F
Example No. 15
alt_label = [
    Label(root, textvar=alt_str[i], justify='left') for i in range(5)
]

for a_l in alt_label:
    a_l.pack(side="left")

root.config(menu=menubar)
# Finally, the application main loop

root.bind("<space>", word_proc)
texto.bind("<Button-1>", focused)
texto.bind("<Left>", focused)
texto.bind("<Right>", focused)
texto.bind("<Up>", focused)
texto.bind("<Down>", focused)
alt_label[0].bind("<Button-1>", swap_alt0)
alt_label[1].bind("<Button-1>", swap_alt1)
alt_label[2].bind("<Button-1>", swap_alt2)
alt_label[3].bind("<Button-1>", swap_alt3)
alt_label[4].bind("<Button-1>", swap_alt4)

BKT = BKTree(levenshtein, dict_words('aymara'))

root.mainloop()
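
dict_words is not shown either; since it is called with 'aymara' here and with './vocab.txt' in the next example, a plausible reading (an assumption) is a plain file loader:

def dict_words(source):
    # loads a newline-delimited word list from the given path
    with open(source, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]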
Example No. 16
def main(argv=None):
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    try:
        os.makedirs(FLAGS.output_dir)
    except OSError as e:
        if e.errno != 17:
            raise

    if FLAGS.use_vacab and os.path.exists("./vocab.txt"):
        bk_tree = BKTree(levenshtein, dict_words('./vocab.txt'))

    with tf.get_default_graph().as_default():
        input_images = tf.placeholder(tf.float32,
                                      shape=[None, None, None, 3],
                                      name='input_images')
        input_transform_matrix = tf.placeholder(tf.float32,
                                                shape=[None, 6],
                                                name='input_transform_matrix')
        # input_box_mask = tf.placeholder(tf.int32, shape=[None], name='input_box_mask')
        input_box_mask = []
        input_box_mask.append(
            tf.placeholder(tf.int32, shape=[None], name='input_box_masks_0'))
        input_box_widths = tf.placeholder(tf.int32,
                                          shape=[None],
                                          name='input_box_widths')
        # input_box_nums = tf.placeholder(tf.int32, name='input_box_nums')
        # input_seq_len = tf.placeholder(tf.int32, shape=[None], name='input_seq_len')
        input_seq_len = input_box_widths[tf.argmax(
            input_box_widths, 0)] * tf.ones_like(input_box_widths)
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        shared_feature, f_score, f_geometry = detect_part.model(input_images)
        pad_rois = roi_rotate_part.roi_rotate_tensor_pad(
            shared_feature, input_transform_matrix, input_box_mask,
            input_box_widths)
        recognition_logits = recognize_part.build_graph(
            pad_rois, input_box_widths)
        _, dense_decode = recognize_part.decode(recognition_logits,
                                                input_box_widths)

        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(
                FLAGS.checkpoint_path,
                os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            im_fn_list = get_images()
            for im_fn in im_fn_list:
                im = cv2.imread(im_fn)[:, :, ::-1]
                start_time = time.time()
                im_resized, (ratio_h, ratio_w) = resize_image(im)
                # im_resized_d, (ratio_h_d, ratio_w_d) = resize_image_detection(im)

                timer = {'net': 0, 'restore': 0, 'nms': 0}
                start = time.time()
                score, geometry = sess.run(
                    [f_score, f_geometry],
                    feed_dict={input_images: [im_resized]})

                boxes, timer = detect(score_map=score,
                                      geo_map=geometry,
                                      timer=timer)
                """
                if boxes is not None:
                    boxes = boxes[:, :8].reshape((-1, 4, 2))
                    boxes[:, :, 0] /= ratio_w
                    boxes[:, :, 1] /= ratio_h
                """
                # save to file
                if boxes is not None and boxes.shape[0] != 0:
                    res_file = os.path.join(
                        FLAGS.output_dir, 'res_' +
                        '{}.txt'.format(os.path.basename(im_fn).split('.')[0]))

                    input_roi_boxes = boxes[:, :8].reshape(-1, 8)
                    # input_roi_boxes = boxes[:, :8].reshape((-1, 4, 2))

                    # input_roi_boxes = boxes.copy()
                    # input_roi_boxes[:, :, 0] *= ratio_w
                    # input_roi_boxes[:, :, 1] *= ratio_h
                    # input_roi_boxes = input_roi_boxes.reshape((-1, 8))
                    # boxes_masks = np.array([0] * input_roi_boxes.shape[0])
                    boxes_masks = [0] * input_roi_boxes.shape[0]
                    transform_matrixes, box_widths = get_project_matrix_and_width(
                        input_roi_boxes)
                    # max_box_widths = max_width * np.ones(boxes_masks.shape[0]) # seq_len

                    # Run end to end
                    recog_decode = sess.run(dense_decode,
                                            feed_dict={
                                                input_images: [im_resized],
                                                input_transform_matrix:
                                                transform_matrixes,
                                                input_box_mask[0]: boxes_masks,
                                                input_box_widths: box_widths
                                            })
                    timer['net'] = time.time() - start

                    # Preparing for draw boxes
                    boxes = boxes[:, :8].reshape((-1, 4, 2))
                    boxes[:, :, 0] /= ratio_w
                    boxes[:, :, 1] /= ratio_h

                    # print "recognition result: "
                    # for pred in recog_decode:
                    # print ground_truth_to_word(pred)
                    if recog_decode.shape[0] != boxes.shape[0]:
                        print "detection and recognition result are not equal!"
                        exit(-1)

                    with open(res_file, 'w') as f:
                        for i, box in enumerate(boxes):
                            # to avoid submitting errors
                            box = sort_poly(box.astype(np.int32))
                            if np.linalg.norm(box[0] -
                                              box[1]) < 5 or np.linalg.norm(
                                                  box[3] - box[0]) < 5:
                                continue
                            recognition_result = ground_truth_to_word(
                                recog_decode[i])
                            if FLAGS.use_vacab:
                                fix_result = bktree_search(
                                    bk_tree, recognition_result.upper())
                                if len(fix_result) != 0:
                                    recognition_result = fix_result[0][1]

                            f.write('{},{},{},{},{},{},{},{},{}\r\n'.format(
                                box[0, 0], box[0, 1], box[1, 0], box[1, 1],
                                box[2, 0], box[2, 1], box[3, 0], box[3, 1],
                                recognition_result))
                            """
                            f.write('{},{},{},{},{},{},{},{}\r\n'.format(
                                box[0, 0], box[0, 1], box[1, 0], box[1, 1], box[2, 0], box[2, 1], box[3, 0], box[3, 1]
                            ))
			    """
                            # Draw bounding box
                            cv2.polylines(
                                im[:, :, ::-1],
                                [box.astype(np.int32).reshape((-1, 1, 2))],
                                True,
                                color=(255, 255, 0),
                                thickness=1)
                            # Draw recognition results area
                            text_area = box.copy()
                            text_area[2, 1] = text_area[1, 1]
                            text_area[3, 1] = text_area[0, 1]
                            text_area[0, 1] = text_area[0, 1] - 15
                            text_area[1, 1] = text_area[1, 1] - 15
                            cv2.fillPoly(im[:, :, ::-1], [
                                text_area.astype(np.int32).reshape((-1, 1, 2))
                            ],
                                         color=(255, 255, 0))
                            im_txt = cv2.putText(im[:, :, ::-1],
                                                 recognition_result,
                                                 (box[0, 0], box[0, 1]), font,
                                                 0.5, (0, 0, 255), 1)
                else:
                    timer['net'] = time.time() - start
                    res_file = os.path.join(
                        FLAGS.output_dir, 'res_' +
                        '{}.txt'.format(os.path.basename(im_fn).split('.')[0]))
                    f = open(res_file, "w")
                    im_txt = None
                    f.close()

                print(
                    '{} : net {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms'.format(
                        im_fn, timer['net'] * 1000, timer['restore'] * 1000,
                        timer['nms'] * 1000))

                duration = time.time() - start_time
                print('[timing] {}'.format(duration))

                if not FLAGS.no_write_images:
                    img_path = os.path.join(FLAGS.output_dir,
                                            os.path.basename(im_fn))
                    # cv2.imwrite(img_path, im[:, :, ::-1])
                    if im_txt is not None:
                        cv2.imwrite(img_path, im_txt)
def main(argv=None):
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    try:
        os.makedirs(FLAGS.output_dir)
    except OSError as e:
        if e.errno != 17:
            raise

    bk_tree = BKTree(levenshtein, list_words(FLAGS.vocab))
    # bk_tree = bktree.Tree()

    with tf.get_default_graph().as_default():
        input_images = tf.placeholder(tf.float32,
                                      shape=[None, None, None, 3],
                                      name='input_images')
        input_feature_map = tf.placeholder(tf.float32,
                                           shape=[None, None, None, 32],
                                           name='input_feature_map')
        input_transform_matrix = tf.placeholder(tf.float32,
                                                shape=[None, 6],
                                                name='input_transform_matrix')
        input_box_mask = []
        input_box_mask.append(
            tf.placeholder(tf.int32, shape=[None], name='input_box_masks_0'))
        input_box_widths = tf.placeholder(tf.int32,
                                          shape=[None],
                                          name='input_box_widths')

        input_seq_len = input_box_widths[tf.argmax(
            input_box_widths, 0)] * tf.ones_like(input_box_widths)
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        shared_feature, f_score, f_geometry = detect_part.model(input_images)
        pad_rois = roi_rotate_part.roi_rotate_tensor_pad(
            input_feature_map, input_transform_matrix, input_box_mask,
            input_box_widths)
        recognition_logits = recognize_part.build_graph(
            pad_rois, input_box_widths)
        _, dense_decode = recognize_part.decode(recognition_logits,
                                                input_box_widths)

        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(
                FLAGS.checkpoint_path,
                os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            # im_fn_list = get_images()
            if FLAGS.just_infer:
                im_fn_list, _, _ = get_image_self(
                    "/data/ceph_11015/ssd/anhan/nba/video2image")
            else:
                im_fn_list, corridate_list, label_list = get_image_self(
                    "/data/ceph_11015/ssd/anhan/nba/video2image")
            wrong = 0
            total = 0
            for ind, im_fn in enumerate(im_fn_list):
                #print("im_fn:",im_fn)
                im = cv2.imread(im_fn)[:, :, ::-1]
                im = cv2.resize(im, (960, 540))

                start_time = time.time()
                im_resized, (ratio_h, ratio_w) = resize_image(im)
                # im_resized_d, (ratio_h_d, ratio_w_d) = resize_image_detection(im)

                timer = {'detect': 0, 'restore': 0, 'nms': 0, 'recog': 0}
                start = time.time()
                shared_feature_map, score, geometry = sess.run(
                    [shared_feature, f_score, f_geometry],
                    feed_dict={input_images: [im_resized]})

                boxes, timer = detect(score_map=score,
                                      geo_map=geometry,
                                      timer=timer)
                timer['detect'] = time.time() - start
                start = time.time()  # reset for recognition
                res = None
                str_list = []
                if boxes is not None and boxes.shape[0] != 0:
                    #res_file_path = os.path.join(FLAGS.output_dir,'res_' + '{}.txt'.format(os.path.basename(im_fn).split('.')[0]))
                    # res_file_path = os.path.join(FLAGS.output_dir, '{}.txt'.format(os.path.basename(im_fn)))

                    input_roi_boxes = boxes[:, :8].reshape(-1, 8)
                    recog_decode_list = []
                    # Here avoid too many text area leading to OOM
                    for batch_index in range(input_roi_boxes.shape[0] // 32 +
                                             1):  # test roi batch size is 32
                        start_slice_index = batch_index * 32
                        end_slice_index = min(
                            (batch_index + 1) * 32, input_roi_boxes.shape[0])
                        tmp_roi_boxes = input_roi_boxes[
                            start_slice_index:end_slice_index]

                        boxes_masks = [0] * tmp_roi_boxes.shape[0]
                        transform_matrixes, box_widths = get_project_matrix_and_width(
                            tmp_roi_boxes)
                        #max_box_widths = max_width * np.ones(boxes_masks.shape[0]) # seq_len

                        # Run end to end
                        recog_decode = sess.run(dense_decode,
                                                feed_dict={
                                                    input_feature_map:
                                                    shared_feature_map,
                                                    input_transform_matrix:
                                                    transform_matrixes,
                                                    input_box_mask[0]:
                                                    boxes_masks,
                                                    input_box_widths:
                                                    box_widths
                                                })
                        recog_decode_list.extend([r for r in recog_decode])

                    timer['recog'] = time.time() - start
                    # Preparing for draw boxes
                    boxes = boxes[:, :8].reshape((-1, 4, 2))
                    boxes[:, :, 0] /= ratio_w
                    boxes[:, :, 1] /= ratio_h

                    if len(recog_decode_list) != boxes.shape[0]:
                        print("detection and recognition results are not equal!")
                        exit(-1)

                    scores = {}
                    score_index = 0
                    time_left = {}
                    time_index = 0
                    team_name = {}
                    quarter_dict = {}
                    remainder_attack_time = {}
                    remainder_attack_time_index = 0
                    recognition_result_num = 0
                    points = {}
                    for i, box in enumerate(boxes):
                        # to avoid submitting errors
                        box = sort_poly(box.astype(np.int32))
                        if np.linalg.norm(box[0] -
                                          box[1]) < 5 or np.linalg.norm(
                                              box[3] - box[0]) < 5:
                            continue
                        recognition_result = ground_truth_to_word(
                            recog_decode_list[i])

                        if contain_eng(recognition_result):
                            #print(recognition_result)
                            fix_result = bktree_search(
                                bk_tree, recognition_result.lower())
                            #print(fix_result)
                            if len(fix_result) != 0:
                                recognition_result = fix_result[0][1]
                                #print(recognition_result)
                        # non-English results are left as-is

                        if recognition_result in all_team:
                            team_name[recognition_result] = [
                                (int(box[0, 0]) + int(box[2, 0])) / 2,
                                (int(box[0, 1]) + int(box[2, 1])) / 2
                            ]
                            points[recognition_result] = [box[0, 0], box[2, 0]]

                        if recognition_result in quarter:
                            quarter_dict[recognition_result] = [
                                (int(box[0, 0]) + int(box[2, 0])) / 2,
                                (int(box[0, 1]) + int(box[2, 1])) / 2
                            ]
                            points[recognition_result] = [box[0, 0], box[2, 0]]

                        if recognition_result.isdigit():
                            scores[recognition_result + "_" +
                                   str(score_index)] = [
                                       (int(box[0, 0]) + int(box[2, 0])) / 2,
                                       (int(box[0, 1]) + int(box[2, 1])) / 2
                                   ]
                            points[recognition_result + "_" +
                                   str(score_index)] = [box[0, 0], box[2, 0]]
                            score_index += 1

                        if ":" in recognition_result:
                            time_left[recognition_result + "_" +
                                      str(time_index)] = [
                                          (int(box[0, 0]) + int(box[2, 0])) /
                                          2,
                                          (int(box[0, 1]) + int(box[2, 1])) / 2
                                      ]
                            points[recognition_result + "_" +
                                   str(time_index)] = [box[0, 0], box[2, 0]]
                            time_index += 1

                        if "." in recognition_result and ":" not in recognition_result:
                            remainder_attack_time[
                                recognition_result + "_" +
                                str(remainder_attack_time_index)] = [
                                    (int(box[0, 0]) + int(box[2, 0])) / 2,
                                    (int(box[0, 1]) + int(box[2, 1])) / 2
                                ]
                            points[recognition_result + "_" +
                                   str(remainder_attack_time_index)] = [
                                       box[0, 0], box[2, 0]
                                   ]
                            remainder_attack_time_index += 1

                        recognition_result_num += 1
                        str_list.append(recognition_result)
                        # Draw bounding box
                        # cv2.polylines(im, [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 255, 0), thickness=1)
                        # Draw recognition results area
                        # text_area = box.copy()
                        # text_area[2, 1] = text_area[1, 1]
                        # text_area[3, 1] = text_area[0, 1]
                        # text_area[0, 1] = text_area[0, 1] - 15
                        # text_area[1, 1] = text_area[1, 1] - 15
                        # cv2.fillPoly(im, [text_area.astype(np.int32).reshape((-1, 1, 2))], color=(255, 255, 0))
                        # im_txt = cv2.putText(im, recognition_result, (box[0, 0], box[0, 1]), font, 0.5, (0, 0, 255), 1)
                        # To add Chinese text:
                        # im_txt = cv2ImgAddText(im, recognition_result, box[0, 0], box[0, 1], (0, 0, 149), 20)

                    if recognition_result_num in (5, 6, 7, 8):
                        res = get_content(remainder_attack_time, time_left,
                                          team_name, scores, quarter_dict)
                    elif recognition_result_num == 9:
                        sort_points = sorted(points.items(),
                                             key=lambda item: item[1][0])
                        x_coordiate = []
                        for pair in sort_points:
                            x_coordiate.append(pair[1][0])
                            x_coordiate.append(pair[1][1])
                        x_sort = sorted(x_coordiate)
                        if x_sort == x_coordiate:
                            drop1 = sort_points[1][0]
                            drop2 = sort_points[4][0]
                            if drop1 in remainder_attack_time:
                                remainder_attack_time = remove_key(
                                    remainder_attack_time, drop1)
                            if drop2 in remainder_attack_time:
                                remainder_attack_time = remove_key(
                                    remainder_attack_time, drop2)
                            if drop1 in time_left:
                                time_left = remove_key(time_left, drop1)
                            if drop2 in time_left:
                                time_left = remove_key(time_left, drop2)
                            if drop1 in scores:
                                scores = remove_key(scores, drop1)
                            if drop2 in scores:
                                scores = remove_key(scores, drop2)
                        res = get_content(remainder_attack_time, time_left,
                                          team_name, scores, quarter_dict)
                if not FLAGS.just_infer:
                    corridate_true = corridate_list[ind].split("_")[4:]
                    label_true = label_list[ind].split("_")
                    res_true = get_score_info_v2(corridate_true, label_true)
                    if res != res_true:
                        #print(im_fn.split("/")[-1],'wrong!!!')
                        wrong += 1
                        #print(im_fn.split("/")[-1],label_list[ind],res_true,res,("_").join(str_list))
                    total += 1
                    print(
                        im_fn.split("/")[-1], label_list[ind], res_true, res,
                        ("_").join(str_list))
                else:
                    print(im_fn.split("/")[-1], res, ("_").join(str_list))
                duration = time.time() - start_time
                #print('{} : detect {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms, recog {:.0f}ms'.format(im_fn, timer['detect']*1000, timer['restore']*1000, timer['nms']*1000, timer['recog']*1000))
            print("wrong:{}".format(wrong))
            print("total:{}".format(total))
            print("precision:{}".format((total - wrong) / total))