Code example #1
File: query.py  Project: SebiSebi/DataMine
def get_similar_sentences(query):
    embeddings = load_embeddings()
    sentence_ids = load_sentence_ids()
    index = AnnoyIndex(get_embeddings_dim(embeddings), "angular")
    index.load("index.ann")
    print("Found {} items in the index.".format(index.get_n_items()))
    print("The index uses {} trees.".format(index.get_n_trees()))
    print("")
    closest, dists = index.get_nns_by_vector(
        embeddings[query], 10, include_distances=True)  # noqa: E501
    assert (len(closest) == len(dists))
    closest = map(lambda sid: sentence_ids[sid], closest)
    return zip(closest, dists)
Code example #2
File: query.py  Project: SebiSebi/DataMine
def annotate_all_questions():
    embeddings = load_embeddings()
    sentence_ids = load_sentence_ids()
    index = AnnoyIndex(get_embeddings_dim(embeddings), "angular")
    index.load("index.ann")
    print("Found {} items in the index.".format(index.get_n_items()))
    print("The index uses {} trees.".format(index.get_n_trees()))
    print("")

    df = pd.concat(map(dm.ALLEN_AI_OBQA, list(OBQAType)))
    annotations = {}
    for _, row in tqdm.tqdm(df.iterrows(), total=len(df)):
        for answer in row.answers:
            sent = row.question + " " + answer
            closest = index.get_nns_by_vector(embeddings[sent], 75)
            closest = list(map(lambda sid: sentence_ids[sid], closest))
            annotations[sent] = closest
    pickle.dump(annotations, open("annotations.pkl", "wb"))
    print("Annotations written to annotations.pkl")
Code example #3
File: index_test.py  Project: zzszmyf/annoy
    def test_very_large_index(self):
        # 388
        f = 3
        dangerous_size = 2**31
        size_per_vector = 4 * (f + 3)
        n_vectors = int(dangerous_size / size_per_vector)
        m = AnnoyIndex(3, 'angular')
        m.verbose(True)
        for i in range(100):
            m.add_item(n_vectors + i, [random.gauss(0, 1) for z in range(f)])
        n_trees = 10
        m.build(n_trees)
        path = 'test_big.annoy'
        m.save(path)  # Raises on Windows

        # Sanity check size of index
        self.assertGreaterEqual(os.path.getsize(path), dangerous_size)
        self.assertLess(os.path.getsize(path), dangerous_size + 100e3)

        # Sanity check number of trees
        self.assertEqual(m.get_n_trees(), n_trees)
Code example #4
class AnnoySearch:
    def __init__(self,
                 vec_dim=2048,
                 lmdb_file="static/lmdb",
                 ann_file="static/annoy_file/tree.ann",
                 metric='angular',
                 num_trees=10):
        self.vec_dim = vec_dim  # dimension of the vectors to index
        self.metric = metric  # metric can be "angular", "euclidean", "manhattan", "hamming", or "dot"
        self.annoy_instance = AnnoyIndex(self.vec_dim, self.metric)
        self.lmdb_file = lmdb_file
        self.ann_file = ann_file
        self.num_trees = num_trees
        self.logger = logging.getLogger('AnnoySearch')

    def save_annoy(self):
        self.annoy_instance.save(self.ann_file)
        self.logger.info('save annoy SUCCESS !')

    def unload_annoy(self):
        self.annoy_instance.unload()

    def load_annoy(self):
        try:
            self.annoy_instance.unload()
            self.annoy_instance.load(self.ann_file)
            self.logger.info('load annoy SUCCESS !')
        except FileNotFoundError:
            self.logger.error(
                'annoy file DOES NOT EXIST , load annoy FAILURE !',
                exc_info=True)

    # create the annoy index from the vectors stored in lmdb
    def create_index_from_lmdb(self):
        # iterate over every key/value pair in the lmdb database
        lmdb_file = self.lmdb_file
        if os.path.isdir(lmdb_file):
            env = lmdb.open(lmdb_file)
            txn = env.begin()
            for key, value in txn.cursor():
                key = int(key)
                value = str2embed(value)
                print(len(value))
                self.annoy_instance.add_item(key, value)

            self.annoy_instance.build(self.num_trees)
            self.annoy_instance.save(self.ann_file)

    def build_annoy(self):
        self.annoy_instance.build(self.num_trees)

    def get_nns_by_item(self,
                        index,
                        nn_num,
                        search_k=-1,
                        include_distances=False):
        return self.annoy_instance.get_nns_by_item(index, nn_num, search_k,
                                                   include_distances)

    def get_nns_by_vector(self,
                          vec,
                          nn_num,
                          search_k=-1,
                          include_distances=False):
        return self.annoy_instance.get_nns_by_vector(vec, nn_num, search_k,
                                                     include_distances)

    def get_n_items(self):
        return self.annoy_instance.get_n_items()

    def get_n_trees(self):
        return self.annoy_instance.get_n_trees()

    def get_vec_dim(self):
        return self.vec_dim

    def add_item(self, index, vec):
        self.annoy_instance.add_item(index, vec)

    def get_item_vector(self, index):
        return self.annoy_instance.get_item_vector(index)
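
A rough usage sketch for the class above (not from the source project; the lmdb contents and the queried item id are assumptions):

searcher = AnnoySearch(vec_dim=2048,
                       lmdb_file="static/lmdb",
                       ann_file="static/annoy_file/tree.ann",
                       num_trees=10)
searcher.create_index_from_lmdb()  # read vectors from lmdb, build and save the index
searcher.load_annoy()              # map the saved index back into memory
print(searcher.get_n_items(), "items,", searcher.get_n_trees(), "trees")
neighbors = searcher.get_nns_by_item(0, 10)  # 10 nearest neighbors of item 0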
Code example #5
File: index_test.py  Project: zzszmyf/annoy
    def test_get_n_trees(self):
        i = AnnoyIndex(10, 'angular')
        i.load('test/test.tree')
        self.assertEqual(i.get_n_trees(), 10)
Code example #6
File: find_match.py  Project: evanc577/sourcecatcher
def find(location, path):
    """find the closest images to an image

    Given a path or a url to an image, returns the closest matches
    (phash hamming distance)

    Arguments:
    location: 'url' or 'path'
    path: the actual url or path to the image
    """

    # load database and annoy index
    index = AnnoyIndex(64, metric='hamming')
    index.load('live/phash_index.ann')
    conn = sqlite3.connect('live/twitter_scraper.db')
    c = conn.cursor()

    # load the requested image
    img = load_image(location, path)

    start_time = time.time()

    # get the image's phash
    phash = imagehash.phash(img)
    phash_arr = phash.hash.flatten()

    # find the closest matches
    n = 16
    n_trees = index.get_n_trees()
    ann_start_time = time.time()
    annoy_results = index.get_nns_by_vector(phash_arr,
                                            n,
                                            include_distances=True,
                                            search_k=100 * n * n_trees)
    ann_end_time = time.time()

    # look up the location of the match and its tweet info
    results = []
    for idx, score in map(list, zip(*annoy_results)):
        # only keep close enough matches
        if score > 8:
            break

        # find respective image in database
        c.execute('SELECT path, filename FROM hashes WHERE idx=(?)', (idx, ))
        dirname, basename = c.fetchone()
        c.execute('SELECT id FROM info WHERE filename=(?) AND path=(?)',
                  (basename, dirname))
        tweet_id = c.fetchone()
        tweet_id = tweet_id[0]
        results.append((score, tweet_id, basename))

    conn.close()

    # sort results
    results = sorted(results, key=lambda x: (-x[0], x[1]))

    end_time = time.time()

    print(results)
    print(f"total search time (phash): {end_time - start_time:06f} seconds")
    print(
        f"annoy search time (phash): {ann_end_time - ann_start_time:06f} seconds"
    )

    return results
Code example #7
File: retriever.py  Project: JamesHujy/ELV
class Retriever(object):
    def __init__(self, args, tokenizer, hidden_size=768):
        self.embedmodel = SentenceTransformer('bert-base-nli-mean-tokens')
        self.tokenizer = tokenizer
        self.embed_list = []
        self.sentence_list = []
        self.sentence2exp = {}
        self.annoy = AnnoyIndex(hidden_size, metric='angular')
        self.load_exp(args.labeled_data)
        self.loaded = False
        # self.load_unlabeled_sen(args.unlabeled_data)

    def load_exp(self, filename):
        sentence_list = []
        with open(filename, encoding='utf-8') as f:
            json_file = json.load(f)
            for item in json_file:
                if 'term' in item:
                    sentence = item['term'] + self.tokenizer.sep_token + item[
                        'sent']
                else:
                    sentence = item['sent']
                exp = item['exp']
                sentence_list.append(sentence)
                self.sentence2exp[sentence] = exp
        self.embed_list.extend(
            self.embedmodel.encode(sentence_list, show_progress_bar=True))
        for i, embed in enumerate(self.embed_list):
            self.annoy.add_item(i, embed)
        self.sentence_list.extend(sentence_list)

    def load_unlabeled_sen(self, filename):
        sentence_list = []
        with open(filename, encoding='utf-8') as f:
            json_file = json.load(f)
            for item in json_file:
                if 'term' in item:
                    sentence = item['term'] + self.tokenizer.sep_token + item[
                        'sent']
                else:
                    sentence = item['sent']
                sentence_list.append(sentence)
        self.sentence_list.extend(sentence_list)
        self.embed_list.extend(
            self.embedmodel.encode(sentence_list, show_progress_bar=True))
        for i, embed in enumerate(self.embed_list):
            self.annoy.add_item(i, embed)
        self.loaded = True

    def update_exp(self, sentence_list, exp_list):
        for sen, exp in zip(sentence_list, exp_list):
            self.sentence2exp[sen] = exp

    def retrieve(self, sentence, nums, get_similar=False):
        if self.annoy.get_n_trees() == 0:
            self.annoy.build(100)

        if sentence in self.sentence_list:
            index = self.sentence_list.index(sentence)
            result = self.annoy.get_nns_by_item(index,
                                                nums,
                                                include_distances=get_similar)
        else:
            target_embed = self.embedmodel.encode([sentence],
                                                  show_progress_bar=False)[0]
            result = self.annoy.get_nns_by_vector(
                target_embed, nums, include_distances=get_similar)

        if get_similar:
            candidate_exp = [
                self.sentence2exp[self.sentence_list[index]]
                for index in result[0]
            ]
            return candidate_exp, result[1]
        else:
            candidate_exp = [
                self.sentence2exp[self.sentence_list[index]]
                for index in result
            ]
            return candidate_exp

    def compute_sim(self, embed1, embed2):
        return np.dot(embed1, embed2)
Code example #8
class AnnoySearch:
    def __init__(self, vec_dim=100, metric='angular'):
        self.vec_dim = vec_dim  # dimension of the vectors to index
        self.metric = metric  # metric can be "angular", "euclidean", "manhattan", "hamming", or "dot"
        self.annoy_instance = AnnoyIndex(self.vec_dim, self.metric)
        self.logger = logging.getLogger('AnnoySearch')

    def save_annoy(self, annoy_file, prefault=False):
        self.annoy_instance.save(annoy_file, prefault=prefault)
        self.logger.info('save annoy SUCCESS !')

    def unload_annoy(self):
        self.annoy_instance.unload()

    def load_annoy(self, annoy_file, prefault=False):
        try:
            self.annoy_instance.unload()
            self.annoy_instance.load(annoy_file, prefault=prefault)
            self.logger.info('load annoy SUCCESS !')
        except FileNotFoundError:
            self.logger.error(
                'annoy file DOES NOT EXIST , load annoy FAILURE !',
                exc_info=True)

    # create the annoy index
    def build_annoy(self, n_trees):
        self.annoy_instance.build(n_trees)

    # query nearest neighbors by item index
    def get_nns_by_item(self,
                        index,
                        nn_num,
                        search_k=-1,
                        include_distances=False):
        return self.annoy_instance.get_nns_by_item(index, nn_num, search_k,
                                                   include_distances)

    # query nearest neighbors by vector
    def get_nns_by_vector(self,
                          vec,
                          nn_num,
                          search_k=-1,
                          include_distances=False):
        return self.annoy_instance.get_nns_by_vector(vec, nn_num, search_k,
                                                     include_distances)

    def get_n_items(self):
        return self.annoy_instance.get_n_items()

    def get_n_trees(self):
        return self.annoy_instance.get_n_trees()

    def get_vec_dim(self):
        return self.vec_dim

    # add an item
    def add_item(self, index, vec):
        self.annoy_instance.add_item(index, vec)

    def get_item_vector(self, index):
        return self.annoy_instance.get_item_vector(index)
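
A corresponding sketch for this variant, showing a build/save/load round trip (the random vectors and the 'tree.ann' file name are illustrative, not from the original):

import random

searcher = AnnoySearch(vec_dim=100, metric='angular')
for i in range(1000):
    searcher.add_item(i, [random.gauss(0, 1) for _ in range(100)])
searcher.build_annoy(10)
searcher.save_annoy('tree.ann')

other = AnnoySearch(vec_dim=100, metric='angular')
other.load_annoy('tree.ann')
print(other.get_n_items(), "items,", other.get_n_trees(), "trees")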
Code example #9
def find_similar(img_path, location='file'):
    print(img_path)
    global kmeans

    # load files
    annoy_map = joblib.load('live/BOW_annoy_map.pkl')
    kmeans = joblib.load('live/kmeans.pkl')

    index = AnnoyIndex(kmeans.n_clusters, 'angular')
    index.load('live/BOW_index.ann')

    conn = sqlite3.connect('live/twitter_scraper.db')
    c = conn.cursor()

    # compute histogram
    start_time = time.time()
    try:
        hist = image_detect_and_compute(img_path, location=location)
    except cv2.error:
        return []

    # find most similar images
    n = 12
    n_trees = index.get_n_trees()
    ann_start_time = time.time()
    annoy_results = index.get_nns_by_vector(hist,
                                            n,
                                            include_distances=True,
                                            search_k=-1)
    ann_end_time = time.time()

    # process results
    results = []
    max_score = -1
    for i, idx in enumerate(annoy_results[0]):
        # discard bad results
        if annoy_results[1][i] > 1.0:
            break

        score = int(100 * (1 - annoy_results[1][i]))
        if i == 0:
            max_score = score
        elif max_score - score > 10:
            break

        # get tweet info
        path = annoy_map[idx]
        basename = os.path.basename(path)
        dirname = os.path.dirname(path)
        c.execute('SELECT id FROM info WHERE filename=(?) AND path=(?)',
                  (basename, dirname))
        tweet_id = c.fetchone()[0]
        tup = (
            score,
            tweet_id,
            basename,
        )
        results.append(tup)

    end_time = time.time()

    print(results)
    print(f"total search time (cbir): {end_time - start_time:06f} seconds")
    print(
        f"annoy search time (cbir): {ann_end_time - ann_start_time:06f} seconds"
    )

    return results
Code example #10
File: index_test.py  Project: spotify/annoy
    def test_get_n_trees(self):
        i = AnnoyIndex(10)
        i.load('test/test.tree')
        self.assertEqual(i.get_n_trees(), 10)
Code example #11
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 13 10:57:22 2019

@author: Kazem
"""

from annoy import AnnoyIndex
dim = 2
a = AnnoyIndex(dim, 'euclidean')

points = [[1, 0], [2, 1], [1, 3], [-2, 2], [-1, -2], [3, -1], [-1, 3]]

for i, v in enumerate(points):
    a.add_item(i, v)

print("{} items are added".format(a.get_n_items()))

a.build(30)

print("{} trees are made".format(a.get_n_trees()))

# the 2 nearest neighbors of item 6 (the point [-1, 3])
result = a.get_nns_by_item(6, 2)

# the 2 nearest neighbors of the query vector [0, 4]
items = a.get_nns_by_vector([0, 4], 2)
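
For completeness, a short follow-on sketch (continuing with the index `a` built above; not part of the original snippet) that also retrieves the distances:

ids, dists = a.get_nns_by_vector([0, 4], 2, include_distances=True)
for item_id, dist in zip(ids, dists):
    print(item_id, a.get_item_vector(item_id), dist)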