Code example #1
 def test_get_item_vector(self):
     f = 10
     i = AnnoyIndex(f, 'euclidean')
     i.add_item(0, [random.gauss(0, 1) for x in xrange(f)])
     for j in xrange(100):
         print(j, '...')
         for k in xrange(1000 * 1000):
             i.get_item_vector(0)
Code example #2
 def test_get_item_vector(self):
     f = 10
     i = AnnoyIndex(f, 'euclidean')
     i.add_item(0, [random.gauss(0, 1) for x in xrange(f)])
     for j in xrange(100):
         print(j, '...')
         for k in xrange(1000 * 1000):
             i.get_item_vector(0)
Code example #3
 def test_load_save(self):
     # Issue #61
     i = AnnoyIndex(10)
     i.load('test/test.tree')
     u = i.get_item_vector(99)
     i.save('x.tree')
     v = i.get_item_vector(99)
     self.assertEqual(u, v)
     j = AnnoyIndex(10)
     j.load('test/test.tree')
     w = i.get_item_vector(99)
     self.assertEqual(u, w)
Code example #4
 def test_load_save(self):
     # Issue #61
     i = AnnoyIndex(10)
     i.load('test/test.tree')
     u = i.get_item_vector(99)
     i.save('x.tree')
     v = i.get_item_vector(99)
     self.assertEquals(u, v)
     j = AnnoyIndex(10)
     j.load('test/test.tree')
     w = i.get_item_vector(99)
     self.assertEquals(u, w)
Code example #5
 def test_item_vector_after_save(self):
     # Issue #279
     a = AnnoyIndex(3)
     a.verbose(True)
     a.add_item(1, [1, 0, 0])
     a.add_item(2, [0, 1, 0])
     a.add_item(3, [0, 0, 1])
     a.build(-1)
     self.assertEquals(a.get_n_items(), 4)
     a.get_item_vector(3)
     a.save('something.annoy')
     self.assertEquals(a.get_n_items(), 4)
     a.get_item_vector(3)
Code example #6
    def test_load_save_get_item_vector(self):
        f = 3
        i = AnnoyIndex(f)
        i.add_item(0, [1.1, 2.2, 3.3])
        i.add_item(1, [4.4, 5.5, 6.6])
        i.add_item(2, [7.7, 8.8, 9.9])
 
        numpy.testing.assert_array_almost_equal(i.get_item_vector(0), [1.1, 2.2, 3.3])
        self.assertTrue(i.build(10))
        self.assertTrue(i.save('blah.ann'))
        numpy.testing.assert_array_almost_equal(i.get_item_vector(1), [4.4, 5.5, 6.6])
        j = AnnoyIndex(f)
        self.assertTrue(j.load('blah.ann'))
        numpy.testing.assert_array_almost_equal(j.get_item_vector(2), [7.7, 8.8, 9.9])
Code example #7
File: annoy_test.py Project: wangxggc/annoy
    def test_load_save_get_item_vector(self):
        f = 3
        i = AnnoyIndex(f)
        i.add_item(0, [1.1, 2.2, 3.3])
        i.add_item(1, [4.4, 5.5, 6.6])
        i.add_item(2, [7.7, 8.8, 9.9])
 
        numpy.testing.assert_array_almost_equal(i.get_item_vector(0), [1.1, 2.2, 3.3])
        self.assertTrue(i.build(10))
        self.assertTrue(i.save('blah.ann'))
        numpy.testing.assert_array_almost_equal(i.get_item_vector(1), [4.4, 5.5, 6.6])
        j = AnnoyIndex(f)
        self.assertTrue(j.load('blah.ann'))
        numpy.testing.assert_array_almost_equal(j.get_item_vector(2), [7.7, 8.8, 9.9])
Code example #8
File: index_test.py Project: zzszmyf/annoy
 def test_item_vector_after_save(self):
     # Issue #279
     a = AnnoyIndex(3, 'angular')
     a.verbose(True)
     a.add_item(1, [1, 0, 0])
     a.add_item(2, [0, 1, 0])
     a.add_item(3, [0, 0, 1])
     a.build(-1)
     self.assertEqual(a.get_n_items(), 4)
     self.assertEqual(a.get_item_vector(3), [0, 0, 1])
     self.assertEqual(set(a.get_nns_by_item(1, 999)), set([1, 2, 3]))
     a.save('something.annoy')
     self.assertEqual(a.get_n_items(), 4)
     self.assertEqual(a.get_item_vector(3), [0, 0, 1])
     self.assertEqual(set(a.get_nns_by_item(1, 999)), set([1, 2, 3]))
Code example #9
File: index_test.py Project: spotify/annoy
 def test_item_vector_after_save(self):
     # Issue #279
     a = AnnoyIndex(3)
     a.verbose(True)
     a.add_item(1, [1, 0, 0])
     a.add_item(2, [0, 1, 0])
     a.add_item(3, [0, 0, 1])
     a.build(-1)
     self.assertEqual(a.get_n_items(), 4)
     self.assertEqual(a.get_item_vector(3), [0, 0, 1])
     self.assertEqual(set(a.get_nns_by_item(1, 999)), set([1, 2, 3]))
     a.save('something.annoy')
     self.assertEqual(a.get_n_items(), 4)
     self.assertEqual(a.get_item_vector(3), [0, 0, 1])
     self.assertEqual(set(a.get_nns_by_item(1, 999)), set([1, 2, 3]))
Code example #10
File: annoy_analysis.py Project: tirthshah20/Sampo
def measure_stability(annoy_file, extractions, dimension,
                      queries=20, topk=20, repetitions=10):
    # loading the index
    t = AnnoyIndex(dimension, 'angular')
    t.load(annoy_file)
    # reading all vectors
    vecs = [t.get_item_vector(i) for i in range(extractions)]
    # sampling queries
    q_inds = random.sample(list(range(extractions)), queries)
    # repeating the process of building indices
    all_indices = [t]
    for _ in tqdm(range(repetitions - 1)):
        t = AnnoyIndex(dimension, 'angular')
        for i, v in enumerate(vecs):
            t.add_item(i, v)
        t.build(100)
        all_indices.append(t)
    # checking if the results are the same
    inconsistencies = 0
    for q in tqdm(q_inds):
        all_nns = set()
        for t in all_indices:
            all_nns.update(t.get_nns_by_item(q, topk))
        if len(all_nns) != topk:
            inconsistencies += 1
    # printing the results
    print('{} of the queries had inconsistent neighbors.'.format(inconsistencies
                                                                 / queries))
Code example #11
class MatchingUtil:
    def __init__(self, index_file):
        logging.info('Initialising matching utility...')
        self.index = AnnoyIndex(VECTOR_LENGTH)
        self.index.load(index_file, prefault=True)
        logging.info('Annoy index {} is loaded'.format(index_file))
        with open(index_file + '.mapping', 'rb') as handle:
            self.mapping = pickle.load(handle)
        logging.info('Mapping file {} is loaded'.format(index_file +
                                                        '.mapping'))
        logging.info('Matching utility initialised.')

    def find_similar_items(self, vector, num_matches):
        item_ids = self.index.get_nns_by_vector(vector,
                                                num_matches,
                                                search_k=-1,
                                                include_distances=False)
        identifiers = [self.mapping[item_id] for item_id in item_ids]
        return identifiers

    def find_similar_vectors(self, vector, num_matches):
        items = self.find_similar_items(vector, num_matches)
        vectors = [
            np.array(self.index.get_item_vector(item)) for item in items
        ]
        return vectors
Code example #12
def prepare(model: Model, label: Label) -> Dataset:
    if not path.exists(PREPARE_PATH):
        mkdir(PREPARE_PATH)

    features = []
    labels = []

    prepare_file = PREPARE_TEMPLATE.format(path=PREPARE_PATH,
                                           model=model.name,
                                           label=label.name)
    if path.isfile(prepare_file):
        with open(prepare_file, mode="r", encoding="utf-8") as f:
            for line in f:
                row = json.loads(line)
                features.append(row["feature"])
                labels.append(row["label"])
        return split(
            X=tf.convert_to_tensor(features),
            Y=tf.convert_to_tensor(labels),
        )
    else:
        examples = []
        title_tree = AnnoyIndex(SIZE[model], "angular")
        title_tree.load(TREE[model][label])

        with open(FEATURES_FILE, mode="r", encoding="utf-8") as f:
            for line in f:
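                # each example pairs the per-model image entry with the recipe's title embedding, fetched from the Annoy index by title_index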
                row = json.loads(line)
                examples.append((
                    row["image"][model.name],
                    title_tree.get_item_vector(row["title_index"]),
                    Recipe(
                        id=row["id"],
                        title=row["title"],
                        title_index=row["title_index"],
                        ingredients=row["ingredients"],
                    ),
                ))
        random.shuffle(examples)

        with open(prepare_file, mode="w", encoding="utf-8") as f:
            for example in examples:
                feature, label, recipe = example
                features.append(feature)
                labels.append(label)
                row = {
                    "feature": feature,
                    "label": label,
                    "id": recipe.id,
                    "title": recipe.title,
                    "title_index": recipe.title_index,
                    "ingredients": recipe.ingredients,
                }
                f.write(f"{json.dumps(row)}\n")

        return split(
            X=tf.convert_to_tensor(features),
            Y=tf.convert_to_tensor(labels),
        )
Code example #13
def findNN(x_dim: int,
           y_dim: int,
           codebook_index_path: str,
           codebook_path: str,
           bit_len: int,
           img_path_list: str,
           img_prefix: str,
           threshold=0.5176,
           search_k=-1):
    df = parseBarcodes(codebook_path, bit_len=bit_len)
    # image_array = np.array(([0.94,0,0], [0,0.68,0], [0, 0, 0.73]))
    # First load the annoy index
    u = AnnoyIndex(bit_len, 'euclidean')
    u.load(codebook_index_path)

    # Sort img_path_list in ascending order, because the files arrive in arbitrary order from nextflow
    r = re.compile(rf"{img_prefix}(\d+)")

    def key_func(m):
        return int(r.search(m).group(1))

    img_path_list.sort(key=key_func)

    # read all images and convert them to float
    image_list = [img_as_float(io.imread(img)) for img in img_path_list]
    rows_list = []
    for x in range(0, x_dim):
        for y in range(0, y_dim):
            attribute_dict = {}
            attribute_dict['X'] = x
            attribute_dict['Y'] = y
            pixel_vector = createPixelVector(x, y, image_list)
            index, minimal_distance = [
                element[0] for element in u.get_nns_by_vector(
                    pixel_vector, 1, search_k=-1, include_distances=True)
            ]  # list comprehension is necessary because the get_nns_by_vector function returns a list instead of just values
            nearest_codebook_vector = np.array(u.get_item_vector(index))
            # Create mask of the vector column, where only 1 will be true
            mask = df.Vector.apply(
                lambda x: str(x) == str(nearest_codebook_vector)
            )  # casting to string is necessary to make the comparison correct
            match_row = df[mask]
            for row in match_row.itertuples(
            ):  # 'iterate' over rows just so I can get a row object, otherwise indexing the match_row object is even more wonky
                attribute_dict['Barcode'] = row.Barcode
                attribute_dict['Distance'] = minimal_distance
                attribute_dict['Gene'] = row.Gene
                # If minimal distance not passing the threshold, it will be labeled as background
                if minimal_distance > threshold:
                    gene_label = 0
                else:
                    gene_label = row.Index
                attribute_dict['Gene_Label'] = gene_label
            rows_list.append(attribute_dict)
    result_df = pd.DataFrame(rows_list)
    return result_df
Code example #14
    def evaluate_set(prefix,
                     tails,
                     annoy_tree_file,
                     vector_dims,
                     lock,
                     rank_threshold=100,
                     sample_size=1000):

        #fname = ''.join(annoy_tree_file)
        lock.acquire()
        try:
            annoy_tree = AnnoyIndex(vector_dims)
            annoy_tree.load(annoy_tree_file)
        finally:
            lock.release()

        # annoy_tree = load_annoy_tree(annoy_tree_file, vector_dims)

        print(mp.current_process().name, id(annoy_tree),
              prefix.encode('utf-8'))
        sys.stdout.flush()

        counts = dict()
        counts[True] = 0
        counts[False] = 0

        if len(tails) > sample_size:
            tails = random.sample(tails, sample_size)
        for (comp1, tail1), (comp2, tail2) in itertools.combinations(tails, 2):

            diff = np.array(annoy_tree.get_item_vector(comp2)) - np.array(
                annoy_tree.get_item_vector(tail2))
            predicted = np.array(annoy_tree.get_item_vector(tail1)) + diff

            result = annoy_knn(annoy_tree, predicted, comp1, rank_threshold)

            counts[result] += 1

        annoy_tree.unload(annoy_tree_file)

        return (prefix, float(counts[True]) / (counts[True] + counts[False])
                ) if counts[True] + counts[False] > 0 else (prefix, 0.0)
Code example #15
def load_embeddings(index_path, embedding_size, num_nodes):
    # Load Annoy index which stores the embedded vectors
    index = AnnoyIndex(embedding_size)
    index.load(index_path)

    embeddings = [index.get_item_vector(i) for i in range(num_nodes)]

    # Unload the index to save memory (loading mmaps the index file)
    index.unload()

    # V x D matrix of embeddings
    return np.array(embeddings)
Code example #16
 def test_distance_consistency(self):
     n, f = 1000, 3
     i = AnnoyIndex(f, 'dot')
     for j in range(n):
         i.add_item(j, numpy.random.normal(size=f))
     i.build(10)
     for a in random.sample(range(n), 100):
         indices, dists = i.get_nns_by_item(a, 100, include_distances=True)
         for b, dist in zip(indices, dists):
             self.assertAlmostEqual(
                 dist, numpy.dot(i.get_item_vector(a),
                                 i.get_item_vector(b)))
             self.assertAlmostEqual(dist, i.get_distance(a, b))
Code example #17
File: index_builder.py Project: kostyaev/ann-search
    def merge_indicies(self, index_file_a, index_file_b, sender_urn):
        logger.info("Merging {0} and {1} for {2} index".format(index_file_a, index_file_b, sender_urn))
        index_a = AnnoyIndex(self.feat_size, metric='euclidean')
        index_b = AnnoyIndex(self.feat_size, metric='euclidean')
        new_index = AnnoyIndex(self.feat_size, metric='euclidean')

        index_a.load(index_file_a)
        index_b.load(index_file_b)

        cnt = 0
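        # copy every stored vector out of both source indexes and add them to the new index with sequential ids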
        for i in range(index_a.get_n_items()):
            new_index.add_item(cnt, index_a.get_item_vector(i))
            cnt += 1

        for i in range(index_b.get_n_items()):
            new_index.add_item(cnt, index_b.get_item_vector(i))
            cnt += 1


        new_index_file = index_file_a + ".merged"

        index_a.unload()
        index_b.unload()

        new_index.build(self.n_trees)
        new_index.save(new_index_file)
        logger.info("Merging {0} and {1} for {2} index, total number of items: {3}".format(
                index_file_a,
                index_file_b,
                sender_urn,
                cnt))

        new_index.unload()
        pykka.ActorRegistry.get_by_urn(sender_urn).proxy().complete_compaction(
                new_index_file=new_index_file,
                index_file_a=index_file_a,
                index_file_b=index_file_b
        )
Code example #18
class AnnoyDataIndex(tf.keras.callbacks.Callback):
    def __init__(self,
                 eb_size,
                 labels,
                 metric="euclidean",
                 save_dir=None,
                 progress=True,
                 **kwargs):
        super().__init__(**kwargs)

        self.progress = progress
        self.index = None
        self.metric = metric
        self.eb_size = eb_size
        self.save_dir = save_dir
        self.labels = labels
        self.ids = self.create_ids(labels)

    def create_ids(self, labels):
        return {i: label for i, label in enumerate(labels)}

    def get_label(self, index):
        return self.ids[index]

    def load_index_file(self, file_path):
        self.index = AnnoyIndex(self.eb_size, self.metric)
        self.index.load(file_path, prefault=False)

    def reindex(self, embeddings):
        self.index = AnnoyIndex(self.eb_size, self.metric)

        for i, embedding in tqdm(enumerate(embeddings),
                                 ncols=100,
                                 total=len(embeddings),
                                 disable=not self.progress,
                                 desc="Indexing ... "):
            self.index.add_item(i, embedding)

        self.index.build(10)

        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)
            self.index.save(os.path.join(self.save_dir, "index.ann"))

    def get_item_vector(self, id):
        return self.index.get_item_vector(id)

    def search(self, embedding, include_distances=False, n=20):
        return self.index.get_nns_by_vector(
            embedding, n, search_k=-1, include_distances=include_distances)
Code example #19
File: index_test.py Project: zzszmyf/annoy
 def test_load_save(self):
     # Issue #61
     i = AnnoyIndex(10, 'angular')
     i.load('test/test.tree')
     u = i.get_item_vector(99)
     i.save('i.tree')
     v = i.get_item_vector(99)
     self.assertEqual(u, v)
     j = AnnoyIndex(10, 'angular')
     j.load('test/test.tree')
     w = i.get_item_vector(99)
     self.assertEqual(u, w)
     # Ensure specifying if prefault is allowed does not impact result
     j.save('j.tree', True)
     k = AnnoyIndex(10, 'angular')
     k.load('j.tree', True)
     x = k.get_item_vector(99)
     self.assertEqual(u, x)
     k.save('k.tree', False)
     l = AnnoyIndex(10, 'angular')
     l.load('k.tree', False)
     y = l.get_item_vector(99)
     self.assertEqual(u, y)
Code example #20
File: dot_index_test.py Project: spotify/annoy
 def test_distance_consistency(self):
     n, f = 1000, 3
     i = AnnoyIndex(f, 'dot')
     for j in range(n):
         i.add_item(j, numpy.random.normal(size=f))
     i.build(10)
     for a in random.sample(range(n), 100):
         indices, dists = i.get_nns_by_item(a, 100, include_distances=True)
         for b, dist in zip(indices, dists):
             self.assertAlmostEqual(dist, numpy.dot(
                 i.get_item_vector(a),
                 i.get_item_vector(b)
             ))
             self.assertEqual(dist, i.get_distance(a, b))
Code example #21
File: index_test.py Project: spotify/annoy
 def test_load_save(self):
     # Issue #61
     i = AnnoyIndex(10)
     i.load('test/test.tree')
     u = i.get_item_vector(99)
     i.save('i.tree')
     v = i.get_item_vector(99)
     self.assertEqual(u, v)
     j = AnnoyIndex(10)
     j.load('test/test.tree')
     w = i.get_item_vector(99)
     self.assertEqual(u, w)
     # Ensure specifying if prefault is allowed does not impact result
     j.save('j.tree', True)
     k = AnnoyIndex(10)
     k.load('j.tree', True)
     x = k.get_item_vector(99)
     self.assertEqual(u, x)
     k.save('k.tree', False)
     l = AnnoyIndex(10)
     l.load('k.tree', False)
     y = l.get_item_vector(99)
     self.assertEqual(u, y)
Code example #22
File: manhattan_index_test.py Project: spotify/annoy
 def test_distance_consistency(self):
     n, f = 1000, 3
     i = AnnoyIndex(f, 'manhattan')
     for j in range(n):
         i.add_item(j, numpy.random.normal(size=f))
     i.build(10)
     for a in random.sample(range(n), 100):
         indices, dists = i.get_nns_by_item(a, 100, include_distances=True)
         for b, dist in zip(indices, dists):
             self.assertAlmostEqual(dist, i.get_distance(a, b))
             u = numpy.array(i.get_item_vector(a))
             v = numpy.array(i.get_item_vector(b))
             self.assertAlmostEqual(dist, numpy.sum(numpy.fabs(u - v)))
             self.assertAlmostEqual(dist, sum([abs(float(x)-float(y)) for x, y in zip(u, v)]))
Code example #23
File: evaluate.py Project: ajmarcus/nosranet
class KNN(object):
    def __init__(self, label: Label, model: Model):
        self.tree = AnnoyIndex(SIZE[model], "angular")
        self.tree.load(TREE[model][label])

    def nearest_index(self, y_pred: np.ndarray) -> int:
        return self.tree.get_nns_by_vector(vector=y_pred.tolist(), n=1)[0]

    def nearest(self, y_pred: np.ndarray) -> np.ndarray:
        index = self.nearest_index(y_pred=y_pred)
        return np.asarray(self.tree.get_item_vector(index))

    def distance(self, left_index: int, right_index: int) -> float:
        return self.tree.get_distance(left_index, right_index)
Code example #24
File: annoy_test.py Project: wangxggc/annoy
 def test_distance_consistency(self):
     n, f = 1000, 3
     i = AnnoyIndex(f, 'manhattan')
     for j in xrange(n):
         i.add_item(j, numpy.random.normal(size=f))
     i.build(10)
     for a in random.sample(range(n), 100):
         indices, dists = i.get_nns_by_item(a, 100, include_distances=True)
         for b, dist in zip(indices, dists):
             self.assertAlmostEqual(dist, i.get_distance(a, b))
             u = numpy.array(i.get_item_vector(a))
             v = numpy.array(i.get_item_vector(b))
             self.assertAlmostEqual(dist, numpy.sum(numpy.fabs(u - v)))
             self.assertAlmostEqual(dist, sum([abs(float(x)-float(y)) for x, y in zip(u, v)]))
Code example #25
def main():
    t = AnnoyIndex(200, metric='euclidean')
    lines = list()
    lookup = dict()

    print("loading...")
    index = 0
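    # each row is assumed to look like: b'<phrase>' @@@ v1, v2, ...; split on the marker and parse the comma-separated floats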
    for row in open("phonetic_vectors_every2_d200_reformatted.txt"):
        spl = row.find("@@@")
        line = row[0:spl - 1]
        stripped_line = line[2:-1].lower()  #skip the b''
        vec = row[spl + 3:-1]
        vals = np.array([float(val) for val in vec.split(", ")])
        if stripped_line in lookup:
            continue
        lookup[stripped_line] = index
        lines.append(stripped_line)
        t.add_item(index, vals)
        index += 1
        if index % 50000 == 0:
            print(stripped_line.lower())
            print("{0} vectors loaded".format(index))
    t.build(100)
    print("done.")

    print("Num dict items: {0}".format(len(lookup)))
    print("Num list items: {0}".format(len(lines)))
    print("Num index items: {0}".format(t.get_n_items()))

    try:
        vec = lookup["skating on thin ice"]
        print(vec)
        print(t.get_item_vector(vec))
        print(nn_lookup(t, t.get_item_vector(vec)))
        print([lines[i[0]] for i in nn_lookup(t, t.get_item_vector(vec))])
    except KeyError:
        print("not found")
Code example #26
File: annoy_test.py Project: wangxggc/annoy
 def test_distance_consistency(self):
     n, f = 1000, 3
     i = AnnoyIndex(f, 'euclidean')
     for j in xrange(n):
         i.add_item(j, numpy.random.normal(size=f))
     i.build(10)
     for a in random.sample(range(n), 100):
         indices, dists = i.get_nns_by_item(a, 100, include_distances=True)
         for b, dist in zip(indices, dists):
             self.assertAlmostEqual(dist, i.get_distance(a, b))
             u = numpy.array(i.get_item_vector(a))
             v = numpy.array(i.get_item_vector(b))
             # self.assertAlmostEqual(dist, euclidean(u, v))
             self.assertAlmostEqual(dist, numpy.dot(u - v, u - v) ** 0.5)
             self.assertAlmostEqual(dist, sum([(x-y)**2 for x, y in zip(u, v)])**0.5)
Code example #27
File: hamming_index_test.py Project: spotify/annoy
 def test_basic_conversion(self):
     f = 100
     i = AnnoyIndex(f, 'hamming')
     u = numpy.random.binomial(1, 0.5, f)
     v = numpy.random.binomial(1, 0.5, f)
     i.add_item(0, u)
     i.add_item(1, v)
     u2 = i.get_item_vector(0)
     v2 = i.get_item_vector(1)
     self.assertAlmostEqual(numpy.dot(u - u2, u - u2), 0.0)
     self.assertAlmostEqual(numpy.dot(v - v2, v - v2), 0.0)
     self.assertAlmostEqual(i.get_distance(0, 0), 0.0)
     self.assertAlmostEqual(i.get_distance(1, 1), 0.0)
     self.assertAlmostEqual(i.get_distance(0, 1), numpy.dot(u - v, u - v))
     self.assertAlmostEqual(i.get_distance(1, 0), numpy.dot(u - v, u - v))
Code example #28
File: hamming_index_test.py Project: MeggyCal/annoy
 def test_basic_conversion(self):
     f = 100
     i = AnnoyIndex(f, 'hamming')
     u = numpy.random.binomial(1, 0.5, f)
     v = numpy.random.binomial(1, 0.5, f)
     i.add_item(0, u)
     i.add_item(1, v)
     u2 = i.get_item_vector(0)
     v2 = i.get_item_vector(1)
     self.assertAlmostEqual(numpy.dot(u - u2, u - u2), 0.0)
     self.assertAlmostEqual(numpy.dot(v - v2, v - v2), 0.0)
     self.assertAlmostEqual(i.get_distance(0, 0), 0.0)
     self.assertAlmostEqual(i.get_distance(1, 1), 0.0)
     self.assertAlmostEqual(i.get_distance(0, 1), numpy.dot(u - v, u - v))
     self.assertAlmostEqual(i.get_distance(1, 0), numpy.dot(u - v, u - v))
Code example #29
 def test_distance_consistency(self):
     n, f = 1000, 3
     i = AnnoyIndex(f, 'euclidean')
     for j in xrange(n):
         i.add_item(j, numpy.random.normal(size=f))
     i.build(10)
     for a in random.sample(range(n), 100):
         indices, dists = i.get_nns_by_item(a, 100, include_distances=True)
         for b, dist in zip(indices, dists):
             self.assertAlmostEqual(dist, i.get_distance(a, b))
             u = numpy.array(i.get_item_vector(a))
             v = numpy.array(i.get_item_vector(b))
             # self.assertAlmostEqual(dist, euclidean(u, v))
             self.assertAlmostEqual(dist, numpy.dot(u - v, u - v) ** 0.5)
             self.assertAlmostEqual(dist, sum([(x-y)**2 for x, y in zip(u, v)])**0.5)
Code example #30
class ChexSearch(object):
    """ Searches Chex index for game states and associated games. """

    #TODO: Combine results of board transforms with binary search algo.

    def __init__(self, chex_index, results=10, search_k=40):
        self.chex_index = chex_index
        self.results = results
        self.search_k = search_k
        self.annoy_index = AnnoyIndex(_bitboard_length, metric='angular')
        self.annoy_index.load(os.path.join(self.chex_index, 'annoy.idx'))
        self.chex_sql = SqliteDict(os.path.join(self.chex_index, 'sqlite.idx'))

    def search(self, board):
        """ Searches for board.

            board: game object of type chess.Board

            Return value: [
                (board, similarity score, [(game_id, move number), ...]), ...]
        """

        symmetrical_boards = [
            board_to_bitboard(board),
            invert_board(board),
            flip_board(board),
            reverse_and_flip(board)
        ]
        results = []
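        # query the index with each symmetry-equivalent bitboard and map every hit back to its games through the SQLite key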
        for bitboard in symmetrical_boards:
            for annoy_id, similarity in zip(
                    *self.annoy_index.get_nns_by_vector(
                        bitboard, self.results, include_distances=True)):
                # Recompute ASCII key
                bitboard = self.annoy_index.get_item_vector(annoy_id)
                to_unhexlify = '%x' % int(
                    ''.join(map(str, map(int, bitboard))), 2)
                try:
                    key = binascii.unhexlify(to_unhexlify)
                except TypeError:
                    key = binascii.unhexlify('0' + to_unhexlify)
                results.append((bitboard_to_board(bitboard), similarity,
                                self.chex_sql[key]))
        return results

    def close(self):
        del self.annoy_index
Code example #31
 def test_distance_consistency(self):
     n, f = 1000, 3
     i = AnnoyIndex(f)
     for j in xrange(n):
         i.add_item(j, numpy.random.normal(size=f))
     i.build(10)
     for a in random.sample(range(n), 100):
         indices, dists = i.get_nns_by_item(a, 100, include_distances=True)
         for b, dist in zip(indices, dists):
             self.assertAlmostEqual(dist, i.get_distance(a, b))
             u = i.get_item_vector(a)
             v = i.get_item_vector(b)
             u_norm = numpy.array(u) * numpy.dot(u, u)**-0.5
             v_norm = numpy.array(v) * numpy.dot(v, v)**-0.5
             # cos = numpy.clip(1 - cosine(u, v), -1, 1) # scipy returns 1 - cos
             self.assertAlmostEqual(dist, numpy.dot(u_norm - v_norm, u_norm - v_norm) ** 0.5)
             # self.assertAlmostEqual(dist, (2*(1 - cos))**0.5)
             self.assertAlmostEqual(dist, sum([(x-y)**2 for x, y in zip(u_norm, v_norm)])**0.5)
Code example #32
File: annoy_test.py Project: wangxggc/annoy
 def test_distance_consistency(self):
     n, f = 1000, 3
     i = AnnoyIndex(f)
     for j in xrange(n):
         i.add_item(j, numpy.random.normal(size=f))
     i.build(10)
     for a in random.sample(range(n), 100):
         indices, dists = i.get_nns_by_item(a, 100, include_distances=True)
         for b, dist in zip(indices, dists):
             self.assertAlmostEqual(dist, i.get_distance(a, b))
             u = i.get_item_vector(a)
             v = i.get_item_vector(b)
             u_norm = numpy.array(u) * numpy.dot(u, u)**-0.5
             v_norm = numpy.array(v) * numpy.dot(v, v)**-0.5
             # cos = numpy.clip(1 - cosine(u, v), -1, 1) # scipy returns 1 - cos
             self.assertAlmostEqual(dist, numpy.dot(u_norm - v_norm, u_norm - v_norm) ** 0.5)
             # self.assertAlmostEqual(dist, (2*(1 - cos))**0.5)
             self.assertAlmostEqual(dist, sum([(x-y)**2 for x, y in zip(u_norm, v_norm)])**0.5)
Code example #33
class DualEncoderSearcher(Searcher):
    def __init__(self, args):
        super(DualEncoderSearcher, self).__init__(args)
        self.infer_batch = args.batch(self.tokenizer, args.max_lens)
        self.model = args.model(args)
        self.ann = AnnoyIndex(args.hidden)
        self.ann.load(args.path['ann'])
        self.sess = tf.Session()
        saver = tf.train.Saver()
        saver.restore(self.sess, args.path['model'])

    def search_line(self, line, num=15):
        input_x = self.infer_batch.encode_x(line)
        infer_features = {'input_x_ph': [input_x], 'keep_prob_ph': 1.0}
        infer_fetches, infer_feed = self.model.infer_step(infer_features)
        vec = self.sess.run(infer_fetches, infer_feed)[0][0]
        ids = self.ann.get_nns_by_vector(vec, num)
        vecs = [self.ann.get_item_vector(i) for i in ids]
        sim = [utils.cosine_similarity(vec, i) for i in vecs]
        return list(zip(ids, sim))
Code example #34
        ann_indexer.add_item(i, item)

    ann_indexer.build(10)

    hs, cs = 0, 0
    ann_enn = np.zeros((patterns, 2))

    for i in range(patterns):
        mpi_nn = mpi[candidates[i]][0]

        c = candidates_matrix[i]
        quan_mpi_nn = discretization(data[mpi_nn:mpi_nn + window_size], t_min,
                                     t_max, bits)
        hypo_rdl = rdl(quan_mpi_nn, c, bits)

        enn = exact_1nn(ann_indexer, patterns, ann_indexer.get_item_vector(i),
                        i)[0]

        ann_nn = ann_indexer.get_nns_by_item(i, 2)
        if len(ann_nn) == 0:
            print('Approximate nn not found')
            hypo_bitsave = window_size * bits - hypo_rdl
            hs += 1
            print('{} is {}, saving {}'.format(i, 'hypothesis', hypo_bitsave))
            continue
        else:
            ann_nn = ann_nn[1]
            print('{}, ANN Pair {}, ENN Pair {}'.format(i, ann_nn, enn))

        ann_enn[i] = np.array([ann_nn, enn]).reshape(1, 2)
Code example #35
from annoy import AnnoyIndex

raw_input("<Enter> to create tree.")
tree = AnnoyIndex(500)

raw_input("<Enter> to load tree.")
tree.load("test_tree.ann")

raw_input("<Enter> to load 10,000 vectors.")
q = True
while q:

    for i in xrange(10000):
        tree.get_item_vector(i)

    resp = raw_input("<Enter> to load 10,000 vectors.")
    if resp.strip() == "q":
        q = False

raw_input("<Enter> to unload tree.")

tree.unload("test_tree.ann")

raw_input("done.")
tree.load("test_tree.ann")

raw_input("<Enter> to load 10,000 vectors.")
q = True
while q:

    for i in xrange(10000):
Code example #36
class AnnoySearch:
    def __init__(self,
                 vec_dim=2048,
                 lmdb_file="static/lmdb",
                 ann_file="static/annoy_file/tree.ann",
                 metric='angular',
                 num_trees=10):
        self.vec_dim = vec_dim  # dimension of the vectors to be indexed
        self.metric = metric  # metric can be "angular", "euclidean", "manhattan", "hamming", or "dot"
        self.annoy_instance = AnnoyIndex(self.vec_dim, self.metric)
        self.lmdb_file = lmdb_file
        self.ann_file = ann_file
        self.num_trees = num_trees
        self.logger = logging.getLogger('AnnoySearch')

    def save_annoy(self):
        self.annoy_instance.save(self.ann_file)
        self.logger.info('save annoy SUCCESS !')

    def unload_annoy(self):
        self.annoy_instance.unload()

    def load_annoy(self):
        try:
            self.annoy_instance.unload()
            self.annoy_instance.load(self.ann_file)
            self.logger.info('load annoy SUCCESS !')
        except FileNotFoundError:
            self.logger.error(
                'annoy file DOES NOT EXIST , load annoy FAILURE !',
                exc_info=True)
        # create the annoy index

    def create_index_from_lmdb(self):
        # iterate over the lmdb entries
        lmdb_file = self.lmdb_file
        if os.path.isdir(lmdb_file):
            evn = lmdb.open(lmdb_file)
            wfp = evn.begin()
            for key, value in wfp.cursor():
                key = int(key)
                value = str2embed(value)
                print(len(value))
                self.annoy_instance.add_item(key, value)

            self.annoy_instance.build(self.num_trees)
            self.annoy_instance.save(self.ann_file)

    def build_annoy(self):
        self.annoy_instance.build(self.num_trees)

    def get_nns_by_item(self,
                        index,
                        nn_num,
                        search_k=-1,
                        include_distances=False):
        return self.annoy_instance.get_nns_by_item(index, nn_num, search_k,
                                                   include_distances)

    def get_nns_by_vector(self,
                          vec,
                          nn_num,
                          search_k=-1,
                          include_distances=False):
        return self.annoy_instance.get_nns_by_vector(vec, nn_num, search_k,
                                                     include_distances)

    def get_n_items(self):
        return self.annoy_instance.get_n_items()

    def get_n_trees(self):
        return self.annoy_instance.get_n_trees()

    def get_vec_dim(self):
        return self.vec_dim

    def add_item(self, index, vec):
        self.annoy_instance.add_item(index, vec)

    def get_item_vector(self, index):
        return self.annoy_instance.get_item_vector(index)
Code example #37
def margin_base_score(source_document, source_weights, target_lang):
    if (target_lang == 'si'):
        f = 1024
        u = AnnoyIndex(f, 'euclidean')
        u.load('../index/test_si.ann')

        map_file = open('../index/sent_to_doc_map_si.json', encoding='utf8')
        maps = json.load(map_file)

        cmap_file = open('../index/sent_count_map_en.json', encoding='utf8')
        cmaps = json.load(cmap_file)
    else:
        f = 1024
        u = AnnoyIndex(f, 'euclidean')
        u.load('../index/test_en.ann')
        map_file = open('../index/sent_to_doc_map_en.json', encoding='utf8')
        maps = json.load(map_file)

        cmap_file = open('../index/sent_count_map_en.json', encoding='utf8')
        cmaps = json.load(cmap_file)

    k = 3
    scores = {}
    dict_s = {}
    dict_t = {}

    for i in range(len(source_document)):
        all_docs = []
        source_docs = source_document[i]
        source_weight = source_weights[i]
        for embed in source_docs:
            lst = u.get_nns_by_vector(embed, 10, search_k=100000)
            for sent in lst:
                all_docs.append(maps[str(sent)])

        counts = collections.Counter(all_docs)
        new_list = sorted(all_docs, key=counts.get, reverse=True)

        all_docs = []
        for y in new_list:
            if len(all_docs) == 20:
                break
            if y not in all_docs:
                all_docs.append(y)

        for j in all_docs:
            target_weight = np.load('../db/' + str(j // 1000) + '/w' +
                                    target_lang + '/' + str(j % 1000) + '.npy')

            # add read embedding from ann here
            cc = int(cmaps[str(j)])
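            # cc presumably marks where document j's sentences start in the Annoy index; read them back with get_item_vector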
            target_docs = []
            for kk in range(len(target_weight)):
                target_docs.append(u.get_item_vector(cc + kk))

            # the commented-out line below reads the embeddings from .npy files instead
            # target_docs = np.load('../db/' + str(j // 1000) + '/'+target_lang+'/' + str(j % 1000) + '.npy')

            distance = greedy_mover_distance(source_document[i].copy(),
                                             target_docs.copy(),
                                             source_weights[i].copy(),
                                             target_weight.copy())
            # distance = s[(i,j)]
            scores[(i, j)] = distance

            if (i in dict_s):
                kneighbours = dict_s[i]
                if (len(kneighbours) < k):
                    kneighbours.append(distance)
                else:
                    max_distance = max(kneighbours)
                    if (distance < max_distance):
                        kneighbours.append(distance)
                        kneighbours.remove(max_distance)
                dict_s[i] = kneighbours
            else:
                dict_s[i] = [distance]

            if (j in dict_t):
                kneighbours = dict_t[j]
                if (len(kneighbours) < k):
                    kneighbours.append(distance)
                else:
                    max_distance = max(kneighbours)
                    if (distance < max_distance):
                        kneighbours.append(distance)
                        kneighbours.remove(max_distance)
                dict_t[j] = kneighbours
            else:
                dict_t[j] = [distance]
    for pair in scores:
        score = get_score(dict_s, dict_t, pair, scores, k)
        scores[pair] = score
    return scores
Code example #38
for i, token in enumerate(keywords):
    t.add_item(i, token.embedding)

t.build(100)

if __name__ == "__main__":
    start = time()
    found_match = False
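    # for each token, fetch its nearest keyword from the Annoy index and re-score the pair with an exact cosine similarity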

    for i, token in enumerate(sentence):
        match = t.get_nns_by_vector(token.embedding, 1)[0]
        sim_score = float(
            cos(
                token.embedding.view(-1, EMBEDDING_SIZE),
                torch.tensor(t.get_item_vector(match)).view(
                    -1, EMBEDDING_SIZE)))
        print(sim_score, match, token)
        if 0.6 <= sim_score <= 0.8:
            print(f'Found Counterfeit')
            print(
                f"Found Counterfeit {token} with a {round(sim_score, 2)} similarity to brand {keywords[match]}"
            )
            found_match = True

        if found_match:
            break

    end = time()
    print(end - start)
Code example #39
def ANN(fv, id_list, pkl_dir='output/', dest_dir='output/', mode="batch"):
    #start ann
    if not mode in ["batch", "recent"]:
        print "[Error] Mode error!"
        exit()

    print("[" + mode + "] Derive related news....")
    n = len(fv)
    f = len(fv[0])
    print("n=" + str(n) + ", f=" + str(f) + "\n")

    print("Making Indexing Trees...")
    t = AnnoyIndex(f)  # Length of item vector that will be indexed
    for i in range(n):
        v = fv[i]
        t.add_item(i, v)
        if i >= 500 and i % 500 == 0:
            print("Added...." + str(i))
    print("Build Indexing Trees....")
    t.build(5)

    # store the indexing tree
    tree_name = 'news-indexing-tree.ann'
    if mode == 'recent':
        tree_name = 'recent-' + tree_name
    t.save(dest_dir + tree_name)
    print("Save indexing tree: " + dest_dir + tree_name)

    # default generate 20 related news for each article
    # the result will be put into "output_filename"
    k = 20
    output_filename = 'mirror-news-ann-distance-20.result'

    # t: the indexing tree for all data, t: the current indexing tree
    if mode == "batch":
        pass
    elif mode == "recent":
        # 1. setting the output name
        output_filename = 'recent-' + output_filename

        # 2. loading the indexing tree and id list of all data
        u = AnnoyIndex(f)
        # 2.1. Load indexing tree from all data
        if os.path.exists(dest_dir + "news-indexing-tree.ann"):
            print "Get the indexing tree: " + dest_dir + "news-indexing-tree.ann"
            u.load(dest_dir + "news-indexing-tree.ann")
        else:
            print "[Error] File does not exist:" + dest_dir + "news-indexing-tree.ann"
            print "Run ANN with batch mode first"
            exit()
        # 2.2. Load id list from all data
        if os.path.exists(pkl_dir + "id_list_all.pkl"):
            f_pkl = open(pkl_dir + "id_list_all.pkl", 'r')
            id_list_all = Pickle.load(f_pkl)
        else:
            print "Failed to load: " + pkl_dir + "id_list_all.pkl"
            print "Load fallback id list"
            f_pkl = open("fallback/fallback-id-list-all.pkl", "r")
            id_list_all = Pickle.load(f_pkl)

    g = open(dest_dir + output_filename, 'w')
    pre_t = time.time()
    # generate a list for related news
    for i in range(n):
        news_id = id_list[i]

        knn_news = t.get_nns_by_item(i, k + 1, include_distances=True)
        knn_list = knn_news[0]
        dist_list = knn_news[1]
        del (knn_list[0])
        del (dist_list[0])
        related_news = [(id_list[knn_list[j]], dist_list[j])
                        for j in range(len(knn_list))]

        if mode == "recent":
            vi = t.get_item_vector(i)
            knn_news_all = u.get_nns_by_vector(vi, k, include_distances=True)
            knn_list_all = knn_news_all[0]
            dist_list_all = knn_news_all[1]
            related_news_all = [(id_list_all[knn_list_all[j]], dist_list_all[j])
                                for j in range(len(knn_list_all))]
            # overwrite related_news

            for x in related_news_all:
                if not x[0] in knn_list:
                    related_news.append(x)
            # sort according to score
            related_news = sorted(related_news, key=lambda x: x[1])[0:k]

        related_news_json = json.dumps(related_news)
        g.write(news_id + "\t" + related_news_json + "\n")
        if i % 100 == 0:
            print("Processed:" + str(i) + ", time passed:" +
                  str(time.time() - pre_t) + "(s)")
            pre_t = time.time()
    print "The related news are in: " + dest_dir + output_filename
    g.close()
Code example #40
    known_face_names.append(name)
    face_names = []
while True:
    ret, frame = video_capture.read()
    small_frame = cv2.resize(frame, (0, 0), fx=0.25, fy=0.25)
    rgb_small_frame = small_frame[:, :, ::-1]

   
    face_names = []
        
    face_locations = face_recognition.face_locations(rgb_small_frame)
    face_encodings = face_recognition.face_encodings(rgb_small_frame, face_locations)

    for face_encoding in face_encodings:
        matches_id = u.get_nns_by_vector(face_encoding, 1)[0]
        known_face_encoding = u.get_item_vector(matches_id)
        compare_faces = face_recognition.compare_faces([known_face_encoding], face_encoding)
        name = "unknown"

        if compare_faces[0]:
            name = known_face_names[matches_id]
        face_names.append(name)
        print(face_names)
        
    for (top, right, bottom, left), name in zip(face_locations, face_names):
        top *= 4
        right *= 4
        bottom *= 4
        left *= 4
        cv2.rectangle(frame, (left, top), (right, bottom), (0, 0, 255), 2)
        cv2.rectangle(frame, (left, bottom - 35), (right, bottom), (0, 0, 255), cv2.FILLED)
Code example #41
File: tmp.py Project: edonyM/annoy
from annoy import AnnoyIndex
import random

f = 40
t = AnnoyIndex(f)  # Length of item vector that will be indexed
for i in xrange(1000):
    v = [random.gauss(0, 1) for z in xrange(f)]
    t.add_item(i, v)

t.build(10) # 10 trees
t.save('test.ann')

# ...

u = AnnoyIndex(f)
u.load('test.ann') # super fast, will just mmap the file
print(u.get_nns_by_item(0, 1000)) # will find the 1000 nearest neighbors
item = u.get_item_vector(0)
print(u.get_nns_by_vector(item, 1000)) # will find the 1000 nearest neighbors
#print(len(u.get_nns_by_vector(item, 1000)))
#print(len(set(u.get_nns_by_vector(item, 1000))))
#print(len(u.get_nns_by_item(0, 1000)))
#print(len(set(u.get_nns_by_item(0, 1000))))
#if u.get_nns_by_vector(item, 1000) == u.get_nns_by_item(0, 1000):
#    print("SAME\n")
Code example #42
File: annoyVectorIndex.py Project: zggl/Seq2Seq-Vis
class AnnoyVectorIndex:
    def __init__(self, file_name, dim_vector=500):
        self.u = AnnoyIndex(dim_vector)
        self.u.load(file_name)

    def get_closest(self,
                    ix,
                    k=10,
                    ignore_same_tgt=False,
                    include_distances=False,
                    use_vectors=False):
        if ignore_same_tgt:
            interval_min = ix // 55 * 55
            if use_vectors:
                candidates = self.u.get_nns_by_vector(
                    ix,
                    k + 55,
                    search_k=100000,
                    include_distances=include_distances)
            else:
                candidates = self.u.get_nns_by_item(
                    ix,
                    k + 55,
                    search_k=100000,
                    include_distances=include_distances)
            if include_distances:
                return [
                    k for k in zip(*candidates)
                    if not interval_min <= k[0] <= interval_min + 55
                ][:k]
            else:
                return [
                    k for k in candidates
                    if not interval_min <= k <= interval_min + 55
                ][:k]
        else:
            if use_vectors:
                return list(
                    zip(*self.u.get_nns_by_vector(
                        ix,
                        k,
                        search_k=100000,
                        include_distances=include_distances)))

            else:
                return list(
                    zip(*self.u.get_nns_by_item(
                        ix,
                        k,
                        search_k=100000,
                        include_distances=include_distances)))

    def get_closest_x(self,
                      ixs,
                      k=10,
                      ignore_same_tgt=False,
                      include_distances=False,
                      use_vectors=False):
        res = []
        for ix in ixs:
            res.append(
                self.get_closest(ix, k, ignore_same_tgt, include_distances,
                                 use_vectors))
        return res

    def get_details(self, ixs):
        res = []
        for ix in ixs:
            res.append({
                'index': ix,
                'v': self.u.get_item_vector(ix),
                'pos': self.search_to_sentence_index(ix)
            })

        return res

    def get_vectors(self, ixs):
        return map(lambda x: self.u.get_item_vector(x), ixs)

    def get_vector(self, ix):
        return self.u.get_item_vector(ix)

    def search_to_sentence_index(self, index):
        return index // 55, index % 55

    def sentence_to_search_index(self, sentence, pos_in_sent):
        return sentence * 55 + pos_in_sent
Code example #43
File: testAnnoy.py Project: Gr4ni/Filmdings
#u.load('test.ann') # super fast, will just mmap the file

imdb_id_sr = '0111161' # shawshank redemption
imdb_id_gf1 = '0068646' 
imdb_id_gf2 = '0071562'
imdb_id_ts3 = '0435761'
imdb_id_fpr = '0120815'
imdb_id_fn = '0266543'

# find 10 closest matches
for ann_index in t.get_nns_by_item(imdb_to_index[imdb_id_ts3], 10):
    imdb_id = index_to_imdb[ann_index]
    
    movie_title = cdb.get(str(imdb_id))['title']
    
    print movie_title, t.get_item_vector(ann_index)
    
# distances between some movies
print "GF1 <-> GF2", t.get_distance(imdb_to_index[imdb_id_gf1], imdb_to_index[imdb_id_gf2])
print "GF1 <-> SR", t.get_distance(imdb_to_index[imdb_id_gf1], imdb_to_index[imdb_id_sr])
print "GF2 <-> SR", t.get_distance(imdb_to_index[imdb_id_gf2], imdb_to_index[imdb_id_sr])
print "GF1 <-> TS3", t.get_distance(imdb_to_index[imdb_id_gf1], imdb_to_index[imdb_id_ts3])
print "FPR <-> FN", t.get_distance(imdb_to_index[imdb_id_fpr], imdb_to_index[imdb_id_fn])


with open('closest_matches.csv', 'wb') as f:
    writer = csv.writer(f)
    # find the closest match for each movie
    for imdb_id in imdb_to_index:
        closest_match = t.get_nns_by_item(imdb_to_index[imdb_id], 2)[1]
        distance = t.get_distance(imdb_to_index[imdb_id], closest_match)
Code example #44
def main(_):
    parser = argparse.ArgumentParser(description='TransE.')
    parser.add_argument('--data',
                        dest='data_dir',
                        type=str,
                        help="Data folder",
                        default='./data/FB15k/')
    parser.add_argument('--lr',
                        dest='lr',
                        type=float,
                        help="Learning rate",
                        default=1e-2)
    parser.add_argument("--dim",
                        dest='dim',
                        type=int,
                        help="Embedding dimension",
                        default=256)
    parser.add_argument("--batch",
                        dest='batch',
                        type=int,
                        help="Batch size",
                        default=32)
    parser.add_argument("--worker",
                        dest='n_worker',
                        type=int,
                        help="Evaluation worker",
                        default=3)
    parser.add_argument("--generator",
                        dest='n_generator',
                        type=int,
                        help="Data generator",
                        default=10)
    parser.add_argument("--eval_batch",
                        dest="eval_batch",
                        type=int,
                        help="Evaluation batch size",
                        default=32)
    parser.add_argument("--save_dir",
                        dest='save_dir',
                        type=str,
                        help="Model path",
                        default='./transE')
    parser.add_argument("--load_model",
                        dest='load_model',
                        type=str,
                        help="Model file",
                        default="")
    parser.add_argument("--save_per",
                        dest='save_per',
                        type=int,
                        help="Save per x iteration",
                        default=1)
    parser.add_argument("--eval_per",
                        dest='eval_per',
                        type=int,
                        help="Evaluate every x iteration",
                        default=1)
    parser.add_argument("--max_iter",
                        dest='max_iter',
                        type=int,
                        help="Max iteration",
                        default=30)
    parser.add_argument("--summary_dir",
                        dest='summary_dir',
                        type=str,
                        help="summary directory",
                        default='./transE_summary/')
    parser.add_argument("--keep",
                        dest='drop_out',
                        type=float,
                        help="Keep prob (1.0 keep all, 0. drop all)",
                        default=0.5)
    parser.add_argument("--optimizer",
                        dest='optimizer',
                        type=str,
                        help="Optimizer",
                        default='gradient')
    parser.add_argument("--prefix",
                        dest='prefix',
                        type=str,
                        help="model_prefix",
                        default='DEFAULT')
    parser.add_argument("--loss_weight",
                        dest='loss_weight',
                        type=float,
                        help="Weight on parameter loss",
                        default=1e-2)
    parser.add_argument("--neg_weight",
                        dest='neg_weight',
                        type=float,
                        help="Sampling weight on negative examples",
                        default=0.5)
    parser.add_argument("--save_per_batch",
                        dest='save_per_batch',
                        type=int,
                        help='evaluate and save after every x batches',
                        default=1000)
    parser.add_argument(
        "--outfile_prefix",
        dest='outfile_prefix',
        type=str,
        help='The filename of output file is outfile_prefix.txt',
        default='test_output')
    parser.add_argument("--neg_sample",
                        dest='neg_sample',
                        type=int,
                        help='No. of neg. samples per (h,r) or (t,r) pair',
                        default=5)
    parser.add_argument(
        "--fanout_thresh",
        dest='fanout_thresh',
        type=int,
        help='threshold on fanout of entities to be considered',
        default=2)
    parser.add_argument('--annoy_n_trees',
                        dest='annoy_n_trees',
                        type=int,
                        help='builds a forest of n_trees trees',
                        default=10)
    parser.add_argument(
        '--annoy_search_k',
        dest='annoy_search_k',
        type=int,
        help='During the query it will inspect up to search_k nodes',
        default=-1)
    parser.add_argument('--eval_after',
                        dest='eval_after',
                        type=int,
                        help='Evaluate after this many no. of epochs',
                        default=0)

    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    print(args)

    model = TransE(args.data_dir,
                   embed_dim=args.dim,
                   fanout_thresh=args.fanout_thresh,
                   eval_batch=args.eval_batch)

    train_pos_neg_list, \
    train_loss, train_op = train_ops(model, learning_rate=args.lr,
                                     optimizer_str=args.optimizer,
                                     regularizer_weight=args.loss_weight)

    get_embedding_op = embedding_ops(model)

    # test_input, test_head, test_tail = test_ops(model)
    # f1 = open('%s/%s.txt' % (args.save_dir, args.outfile_prefix),'w')

    with tf.Session() as session:
        tf.global_variables_initializer().run()

        all_var = tf.all_variables()
        print 'printing all', len(all_var), ' TF variables:'
        for var in all_var:
            print var.name, var.get_shape()

        saver = tf.train.Saver(restore_sequentially=True)

        iter_offset = 0

        if args.load_model is not None and os.path.exists(args.load_model):
            saver.restore(session, args.load_model)
            iter_offset = int(
                args.load_model.split('.')[-2].split('_')[-1]) + 1
            print("Load model from %s, iteration %d restored.\n" %
                  (args.load_model, iter_offset))

        total_inst = model.n_train
        best_filtered_mean_rank = float("inf")

        print("preparing training data...\n")
        nbatches_count = 0
        # training_data_list = []
        training_data_pos_neg_list = []

        for dat in model.raw_training_data(batch_size=args.batch):
            # raw_training_data_queue.put(dat)
            # training_data_list.append(dat)
            ps_list = data_generator_func(dat, model.tr_h, model.hr_t,
                                          model.n_entity, args.neg_sample,
                                          model.n_relation)
            assert ps_list is not None
            training_data_pos_neg_list.append(ps_list)
            nbatches_count += 1
        print("training data prepared.\n")
        print("No. of batches : %d\n" % nbatches_count)
        # f1.close()

        start_time = timeit.default_timer()

        for n_iter in range(iter_offset, args.max_iter):
            accu_loss = 0.
            ninst = 0
            # f1.close()

            for batch_id in range(nbatches_count):
                # f1 = open('%s/%s.txt' % (args.save_dir, args.outfile_prefix),'a')

                pos_neg_list = training_data_pos_neg_list[batch_id]
                #print data_e
                l, _ = session.run([train_loss, train_op],
                                   {train_pos_neg_list: pos_neg_list})

                accu_loss += l
                ninst += len(pos_neg_list)

                # print('len(pos_neg_list) = %d\n' % len(pos_neg_list))

                if ninst % (5000) is not None:
                    print('[%d sec](%d/%d) : %.2f -- loss : %.5f \n' %
                          (timeit.default_timer() - start_time, ninst,
                           total_inst, float(ninst) / total_inst, l))
                # f1.close()

            # f1 = open('%s/%s.txt' % (args.save_dir, args.outfile_prefix),'a')
            print("")
            print("iter %d avg loss %.5f, time %.3f\n" %
                  (n_iter, accu_loss / ninst,
                   timeit.default_timer() - start_time))

            # if n_iter == args.max_iter - 1:
            #     save_path = saver.save(session,
            #                            os.path.join(args.save_dir,
            #                                         "TransE_" + str(args.prefix) + "_" + str(n_iter) + ".ckpt"))
            #     print("Model saved at %s\n" % save_path)

            with tf.device('/cpu'):
                if n_iter > args.eval_after and (n_iter % args.eval_per == 0 or
                                                 n_iter == args.max_iter - 1):
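                    # periodic evaluation: index the current entity embeddings with
                    # Annoy so link-prediction candidates come from approximate
                    # nearest-neighbour search rather than an exhaustive ranking
                    # over all entities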

                    t = AnnoyIndex(model.embed_dim, metric='euclidean')

                    ent_embedding, rel_embedding = session.run(
                        get_embedding_op, {train_pos_neg_list: pos_neg_list})
                    # sess = tf.InteractiveSession()
                    # with sess.as_default():
                    #     ent_embedding = model.ent_embeddings.eval()
                    print np.asarray(ent_embedding).shape
                    print np.asarray(rel_embedding).shape

                    # print ent_embedding[10,:]
                    # print rel_embedding[10,:]
                    print 'Index creation started'
                    st_1 = timeit.default_timer()

                    for i in xrange(model.n_entity):
                        v = ent_embedding[i, :]
                        t.add_item(i, v)
                    t.build(args.annoy_n_trees)

                    print 'Index creation completed, time taken = %f' % (
                        timeit.default_timer() - st_1)

                    # n = int(0.0005 * model.n_entity)
                    n = 1000
                    # search_k = int(n * args.annoy_n_trees/100.0)
                    search_k = 1000
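                    # n is the number of neighbours requested per query and search_k
                    # the number of nodes Annoy inspects (larger = more accurate but
                    # slower); both are fixed here, the commented lines above show
                    # size-dependent alternatives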

                    print 'No. of items = %d' % t.get_n_items()
                    print sum(t.get_item_vector(0))
                    print sum(ent_embedding[0, :])
                    # Annoy stores vectors as float32, so compare approximately
                    assert np.allclose(t.get_item_vector(0),
                                       ent_embedding[0, :], atol=1e-5)

                    if n_iter == args.max_iter - 1:
                        eval_dict = zip(
                            [model.validation_data, model.testing_data],
                            ['VALID', 'TEST'])
                    else:
                        eval_dict = zip([model.validation_data], ['VALID'])

                    for data_func, test_type in eval_dict:
                        accu_mean_rank_h = list()
                        accu_mean_rank_t = list()
                        accu_filtered_mean_rank_h = list()
                        accu_filtered_mean_rank_t = list()

                        evaluation_count = 0
                        evaluation_batch = []
                        batch_id = 0

                        st = timeit.default_timer()
                        for testing_data in data_func(
                                batch_size=args.eval_batch):
                            batch_id += 1
                            print 'test_type: %s, batch id: %d' % (test_type,
                                                                   batch_id)
                            head_ids = list()
                            tail_ids = list()
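                            # under the TransE objective (h + r close to t), translating
                            # one entity embedding by the relation embedding and querying
                            # the Annoy index yields ranked candidates for the missing
                            # entity of each test triple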

                            for i in xrange(testing_data.shape[0]):
                                # try:
                                # print (ent_embedding[testing_data[i,0],:] + rel_embedding[testing_data[i,2],:])
                                head_ids.append(
                                    t.get_nns_by_vector(
                                        (ent_embedding[testing_data[i, 0], :] +
                                         rel_embedding[testing_data[i, 2], :]),
                                        n, search_k))
                                tail_ids.append(
                                    t.get_nns_by_vector(
                                        (ent_embedding[testing_data[i, 1], :] -
                                         rel_embedding[testing_data[i, 2], :]),
                                        n, search_k))
                                # except:
                                #     print 'i = %d' % i
                                #     print 'testing_data[i,0] = %d' % testing_data[i,0]
                                #     print 'testing_data[i,1] = %d' % testing_data[i,1]
                                #     print 'testing_data[i,2] = %d' % testing_data[i,2]

                            # print head_ids
                            # print tail_ids
                            evaluation_batch.append(
                                (testing_data, head_ids, tail_ids))
                            evaluation_count += 1

                        et = timeit.default_timer()
                        print 'time taken for head id and tail id calc = %f' % (
                            timeit.default_timer() - st)
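                        # replay the cached batches through worker_func, which turns the
                        # retrieved candidates into raw and filtered ranks (filtering uses
                        # the known triples in hr_t / tr_h)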

                        while evaluation_count > 0:
                            evaluation_count -= 1
                            print 'eval test_type: %s, batch id: %d' % (
                                test_type, evaluation_count)
                            # (mrh, fmrh), (mrt, fmrt) = result_queue.get()
                            (mrh, fmrh), (mrt, fmrt) = worker_func(
                                evaluation_batch[evaluation_count],
                                model.hr_t, model.tr_h)
                            accu_mean_rank_h += mrh
                            accu_mean_rank_t += mrt
                            accu_filtered_mean_rank_h += fmrh
                            accu_filtered_mean_rank_t += fmrt

                        print 'time taken for metric computation = %f' % (
                            timeit.default_timer() - et)

                        print(
                            "[%s] ITER %d [HEAD PREDICTION] MEAN RANK: %.1f FILTERED MEAN RANK %.1f HIT@10 %.3f FILTERED HIT@10 %.3f\n"
                            % (test_type, n_iter, np.mean(accu_mean_rank_h),
                               np.mean(accu_filtered_mean_rank_h),
                               np.mean(
                                   np.asarray(accu_mean_rank_h, dtype=np.int32)
                                   < 10),
                               np.mean(
                                   np.asarray(accu_filtered_mean_rank_h,
                                              dtype=np.int32) < 10)))

                        print(
                            "[%s] ITER %d [TAIL PREDICTION] MEAN RANK: %.1f FILTERED MEAN RANK %.1f HIT@10 %.3f FILTERED HIT@10 %.3f\n"
                            % (test_type, n_iter, np.mean(accu_mean_rank_t),
                               np.mean(accu_filtered_mean_rank_t),
                               np.mean(
                                   np.asarray(accu_mean_rank_t, dtype=np.int32)
                                   < 10),
                               np.mean(
                                   np.asarray(accu_filtered_mean_rank_t,
                                              dtype=np.int32) < 10)))

                        if test_type == 'VALID':
                            filtered_mean_rank = (
                                np.mean(accu_filtered_mean_rank_t) +
                                np.mean(accu_filtered_mean_rank_h)) / 2.0
                            if filtered_mean_rank < best_filtered_mean_rank:
                                save_path = saver.save(
                                    session,
                                    os.path.join(
                                        args.save_dir,
                                        "TransE_" + str(args.prefix) + "_" +
                                        str(n_iter) + ".ckpt"))
                                print("Model saved at %s\n" % save_path)
                                best_filtered_mean_rank = filtered_mean_rank