Ejemplo n.º 1
0
    def __init__(self, dim, save_path, keys_path):
        """Set up a FaissIndex of dimension *dim* and its optional key mapping.

        Args:
            dim: Embedding dimensionality of the index.
            save_path: Path of the serialized index; presumably may be a
                remote URL that down_if_remote_path downloads — confirm
                against that helper.
            keys_path: Optional path of a keys file; may be falsy.
        """
        logging.debug('dim: %d', dim)
        logging.debug('save_path: %s', save_path)
        logging.debug('keys_path: %s', keys_path)

        # Resolve to a local path (and the original remote path, if any).
        remote_path, save_path = down_if_remote_path(save_path)

        self._remote_path = remote_path  # original remote location, or None
        self._save_path = save_path      # local path used for load/save
        self._index = FaissIndex(dim, save_path)
        self._keys, self._key_index = self._load_keys(keys_path)
        logging.debug('ntotal: %d', self._index.ntotal())
Ejemplo n.º 2
0
def update_index(get_faiss_index, get_faiss_id_to_vector, cur,
                 faiss_index_list):
    """Rebuild the inactive Faiss index slot, then flip `cur` to it.

    Args:
        get_faiss_index: Zero-arg callable returning the raw faiss index.
        get_faiss_id_to_vector: Zero-arg callable returning the id->vector
            mapping (or getter) passed to FaissIndex.
        cur: One-element mutable list holding the active slot (0 or 1);
            mutated in place so the caller observes the flip.
        faiss_index_list: Two-slot list; slot `1 - cur[0]` is replaced with
            a freshly built FaissIndex before the flip.

    Fixes over the original: the per-pid log file is opened with a context
    manager (the handle was previously leaked), and Python 2 `print`
    statements are replaced with the `print()` calls already used elsewhere
    in this function (same output).
    """
    file_name = str(os.getpid()) + '.log'
    with open(file_name, 'a') as f:
        print("\n\n========================================================================")
        f.write(
            "\n\n========================================================================\n"
        )
        f.write('Getting Faiss index of [%s]\n' % cur[0])
        print('Getting Faiss index of [%s]' % cur[0])

        # Build into the *inactive* slot so readers of the active slot are
        # never disturbed mid-build.
        faiss_index_list[1 - cur[0]] = FaissIndex(get_faiss_index(),
                                                  get_faiss_id_to_vector())
        print('faiss_index_0:', faiss_index_list[0])
        print('faiss_index_1:', faiss_index_list[1])
        print('Getting Faiss index done')

        f.write('faiss_index_0:{}\n'.format(faiss_index_list[0]))
        f.write('faiss_index_1:{}\n'.format(faiss_index_list[1]))
        f.write('Getting Faiss index done\n')

        # Flip the active slot only after the new index is fully built.
        cur[0] = 1 - cur[0]
        print('cur from %s -> %s' % (1 - cur[0], cur[0]))
        print('current pid:', os.getpid())
        print("========================================================================\n\n")

        f.write('cur from %s -> %s\n' % (1 - cur[0], cur[0]))
        f.write('current pid:%s\n' % os.getpid())
        f.write(
            "========================================================================\n\n"
        )
Ejemplo n.º 3
0
    def __init__(self, dim, save_path, keys_path, nprobe, num_threads=None):
        """Load the index (downloading if remote) and its key mapping.

        Args:
            dim: Embedding dimensionality.
            save_path: Local or remote path of the serialized index.
            keys_path: Optional keys file path; falsy to skip key mapping.
            nprobe: IVF probe count; applied only when > 1.
            num_threads: Optional thread count forwarded to FaissIndex.

        Fix over the original: conf.yaml is opened with a context manager so
        the file handle is always closed (it was previously leaked).
        """
        logging.info("dim: %d", dim)
        logging.info("save_path: %s", save_path)
        logging.info("keys_path: %s", keys_path)
        logging.info("nprobe: %d", nprobe)
        if num_threads is not None:
            logging.info("num_threads: %d", num_threads)

        with open("conf.yaml", 'r') as stream:
            self._conf = yaml.load(stream, Loader=yaml.FullLoader)
        print(self._conf)

        # Resolve to a local path (and keep the remote URL, if any).
        remote_path, save_path = self.down_if_remote_path(save_path)

        self._remote_path = remote_path
        self._save_path = save_path
        self._index = FaissIndex(dim, save_path, num_threads)
        if nprobe > 1:
            self._index.set_nprobe(nprobe)
        self._keys, self._key_index = self._load_keys(keys_path)
        logging.info("ntotal: %d", self._index.ntotal())
Ejemplo n.º 4
0
def search_new():
    """Flask endpoint: search a named on-disk Faiss index with POSTed vectors.

    POST form fields: `database` (index directory name), `k` (neighbour
    count), `vectors` (string of numbers; parsed as consecutive 128-dim
    vectors — any trailing partial vector is dropped, as in the original).
    Returns JSON search results, or an error string on bad input / missing
    index. GET renders the search form.

    Fixes over the original: removed the dead `c[i] != []` comparison (a
    string is never equal to a list, so it was always True) and the unused
    bound exception variable; replaced the index-based loop with chunked
    parsing. Behavior is unchanged.
    """
    if request.method == "POST":
        try:
            db_name = request.form["database"]
            db_vectors = request.form["vectors"]
            k = int(request.form["k"])
            try:
                # Load the previously persisted index for this database.
                index = faiss.read_index('%s/index' % str(db_name))
            except Exception:
                return 'there is no index yet!'
            # Parse the raw string into flat numeric tokens, then group
            # them into 128-dim vectors; a trailing chunk shorter than 128
            # is discarded (matching the original behavior).
            tokens = regex.findall('[0-9.]+', db_vectors)
            usable = len(tokens) - (len(tokens) % 128)
            vectors = [[float(t) for t in tokens[i:i + 128]]
                       for i in range(0, usable, 128)]
            # Ids are 1-based positions of the query vectors.
            ids2 = (np.arange(len(vectors)) + 1).astype('int')
            id_vector = dict(zip(ids2, vectors))
            # Run the search through the shared FaissIndex wrapper.
            create_db.create_db = FaissIndex(index, id_vector, db_name)
            results = create_db.create_db.search_by_vectors(vectors, k)
            return jsonify(results)
        except Exception:
            # NOTE(review): this swallows all errors into one message;
            # kept for API compatibility, but consider logging the cause.
            return 'incorrect input!'
    return render_template("search.html")
Ejemplo n.º 5
0
def record(setup_state):
    """Blueprint record hook: build the shared FaissIndex from app config."""
    config = setup_state.app.config
    index_path = config.get('INDEX_PATH')
    ids_vectors_path = config.get('IDS_VECTORS_PATH')
    blueprint.faiss_index = FaissIndex(index_path, ids_vectors_path)
Ejemplo n.º 6
0
 def set_faiss_index(signal=None):
     """Build the shared FaissIndex and attach it to the blueprint.

     *signal* is accepted (and ignored) so this can be wired directly to a
     signal/event handler.
     """
     print('Getting Faiss index')
     raw_index = get_faiss_index()
     id_to_vector = get_faiss_id_to_vector()
     blueprint.faiss_index = FaissIndex(raw_index, id_to_vector)
Ejemplo n.º 7
0
 def set_faiss_index():
     """Fetch the raw Faiss index and publish it on the blueprint."""
     print('Getting Faiss index')
     raw_index = get_faiss_index()
     blueprint.faiss_index = FaissIndex(raw_index)
Ejemplo n.º 8
0
class FaissServer(pb2_grpc.ServerServicer):
    """gRPC servicer exposing add/remove/search/import over a FaissIndex.

    Optionally maps string keys to integer ids via a CSV of keys (one per
    row, row order == id), so clients can address vectors by key.
    """

    def __init__(self, dim, save_path, keys_path):
        """Load the index of dimension *dim*, downloading remote files first."""
        logging.debug('dim: %d', dim)
        logging.debug('save_path: %s', save_path)
        logging.debug('keys_path: %s', keys_path)

        # Resolve to a local path (and the original remote path, if any).
        remote_path, save_path = down_if_remote_path(save_path)

        self._remote_path = remote_path
        self._save_path = save_path
        self._index = FaissIndex(dim, save_path)
        self._keys, self._key_index = self._load_keys(keys_path)
        logging.debug('ntotal: %d', self._index.ntotal())

    def _load_keys(self, keys_path):
        """Read one key per CSV row; return (values, pandas Index) or (None, None)."""
        if not keys_path:
            return None, None
        _, keys_path = down_if_remote_path(keys_path)
        # NOTE(review): read_csv(squeeze=True) was removed in pandas 2.0;
        # migrate to .squeeze("columns") when the pandas pin allows it.
        keys = pd.read_csv(keys_path, header=None, squeeze=True, dtype=('str'))
        key_index = pd.Index(keys)
        return keys.values, key_index

    def Total(self, request, context):
        """Return the number of vectors currently in the index."""
        return pb2.TotalResponse(count=self._index.ntotal())

    def Add(self, request, context):
        """Insert (or replace) a single embedding under request.id."""
        logging.debug('add - id: %d', request.id)
        xb = np.expand_dims(np.array(request.embedding, dtype=np.float32), 0)
        ids = np.array([request.id], dtype=np.int64)
        self._index.replace(xb, ids)

        return pb2.SimpleResponse(message='Added, %d!' % request.id)

    def Remove(self, request, context):
        """Remove the vector with request.id; report when it did not exist."""
        logging.debug('remove - id: %d', request.id)
        ids = np.array([request.id], dtype=np.int64)
        removed_count = self._index.remove(ids)

        if removed_count < 1:
            return pb2.SimpleResponse(message='Not existed, %s!' % request.id)
        return pb2.SimpleResponse(message='Removed, %s!' % request.id)

    def Search(self, request, context):
        """Nearest-neighbour search by id, or by key when request.key is set."""
        logging.debug('search - id: %d, %s', request.id, request.key)
        if request.key:
            # Guard against a missing key index (matches the sibling server
            # class) and use `in` — pandas Index.contains is deprecated.
            if self._key_index is None or request.key not in self._key_index:
                return pb2.SearchResponse()
            request.id = self._key_index.get_loc(request.key)
        D, I = self._index.search_by_id(request.id, request.count)
        K = None
        if request.key:
            K = self._keys[I[0]]
        return pb2.SearchResponse(ids=I[0], scores=D[0], keys=K)

    def Restore(self, request, context):
        """Reload the index from request.save_path (downloaded first if remote)."""
        logging.debug('restore - %s', request.save_path)
        remote_path, save_path = down_if_remote_path(request.save_path)
        self._remote_path = remote_path
        self._save_path = save_path
        # Bug fix: restore from the *local* path — request.save_path may be
        # a remote URL that the index cannot open directly.
        self._index.restore(save_path)
        return pb2.SimpleResponse(message='Restored, %s!' % request.save_path)

    def Import(self, request, context):
        """Bulk-replace index contents from a TSV of embeddings + CSV of ids."""
        logging.debug('importing - %s, %s', request.embs_path, request.ids_path)
        _, embs_path = down_if_remote_path(request.embs_path)
        _, ids_path = down_if_remote_path(request.ids_path)
        df = read_csv(embs_path, delimiter="\t", header=None)
        X = df.values
        df = read_csv(ids_path, header=None)
        ids = df[0].values
        logging.debug('%s', ids)

        # faiss requires contiguous float32 embeddings and int64 ids.
        X = np.ascontiguousarray(X, dtype=np.float32)
        ids = np.ascontiguousarray(ids, dtype=np.int64)

        self._index.replace(X, ids)
        return pb2.SimpleResponse(message='Imported, %s, %s!' % (request.embs_path, request.ids_path))

    def save(self):
        """Persist the index to the current local save path."""
        logging.debug('saving index to %s', self._save_path)
        self._index.save(self._save_path)
Ejemplo n.º 9
0
    def _new_trained_index(self):
        """Build a FaissIndex from all on-disk embedding files, training it
        on up to ``self.max_train_count`` randomly sampled files.

        Each file 'embeddings/*/<id>.emb' contributes one vector whose
        integer id is the file's basename (minus the 4-char '.emb' suffix).
        Remaining (non-training) files are added afterwards in chunks.
        """
        def path_to_id(filepath):
            # '.../12345.emb' -> 12345 (strip directory prefix, 4-char suffix)
            pos = filepath.rindex('/') + 1
            return int(filepath[pos:-4])

        logging.info("File loading...")
        t0 = time.time()
        all_filepaths = glob.glob('embeddings/*/*.emb')
        total_count = len(all_filepaths)
        logging.info("%d files %.3f s", total_count, time.time() - t0)

        # Cap the training set; with nothing to train on, fall back to a
        # default index from _new_index().
        train_count = min(total_count, self.max_train_count)
        if train_count <= 0:
            return self._new_index()

        # Shuffle so the training subset is a random sample of all files.
        random.shuffle(all_filepaths)

        filepaths = all_filepaths[:train_count]
        t0 = time.time()
        xb = self._path_to_xb(filepaths)
        ids = np.array(list(map(path_to_id, filepaths)), dtype=np.int64)
        logging.info("%d embeddings loaded %.3f s", xb.shape[0], time.time() - t0)

        # Too few samples to train well: use an untrained FaissIndex(d)
        # directly (presumably a flat index — confirm in FaissIndex).
        if train_count < 10000:
            d = self.embedding_service.dim()
            faiss_index = FaissIndex(d)
            faiss_index.add(xb, ids)
            return faiss_index

        # ~39 training points per centroid, capped at self._max_nlist.
        nlist = min(self._max_nlist, int(train_count / 39))
        faiss_index = self._new_index(nlist=nlist)

        logging.info("Training...")
        t0 = time.time()
        faiss_index.train(xb)
        logging.info("trained %.3f s", time.time() - t0)

        # Add the training vectors in batches to bound work per call.
        step = 100000
        for i in range(0, train_count, step):
            t0 = time.time()
            faiss_index.add(xb[i:i+step], ids[i:i+step])
            logging.info("added %.3f s", time.time() - t0)

        # Stream in the remaining (non-training) files in 20k-file chunks.
        if total_count > train_count:
            for filepaths in chunks(all_filepaths[train_count:], 20000):
                t0 = time.time()
                xb = self._path_to_xb(filepaths)
                ids = np.array(list(map(path_to_id, filepaths)), dtype=np.int64)
                faiss_index.add(xb, ids)
                logging.info("%d embeddings added %.3f s", xb.shape[0], time.time() - t0)
            logging.info("Total %d embeddings added", faiss_index.ntotal())
        return faiss_index
Ejemplo n.º 10
0
class FaissServer(pb2_grpc.ServerServicer):
    """gRPC servicer over a FaissIndex with optional string-key addressing.

    Supports s3:// and Azure blobs:// remote paths for the index and key
    files; remote files are downloaded to a unique temp path before use.
    Keys (when provided) map string keys <-> integer ids via a pandas Index.
    """

    def __init__(self, dim, save_path, keys_path, nprobe, num_threads=None):
        """Load the index (downloading it if remote) and the key mapping.

        Args:
            dim: Embedding dimensionality.
            save_path: Local or remote path of the serialized index.
            keys_path: Optional CSV of string keys (row order == id).
            nprobe: IVF probe count; applied only when > 1.
            num_threads: Optional thread count forwarded to FaissIndex.

        Fix over the original: conf.yaml is opened with a context manager so
        the handle is always closed (it was previously leaked).
        """
        logging.info("dim: %d", dim)
        logging.info("save_path: %s", save_path)
        logging.info("keys_path: %s", keys_path)
        logging.info("nprobe: %d", nprobe)
        if num_threads is not None:
            logging.info("num_threads: %d", num_threads)

        with open("conf.yaml", 'r') as stream:
            self._conf = yaml.load(stream, Loader=yaml.FullLoader)
        print(self._conf)

        remote_path, save_path = self.down_if_remote_path(save_path)

        self._remote_path = remote_path
        self._save_path = save_path
        self._index = FaissIndex(dim, save_path, num_threads)
        if nprobe > 1:
            self._index.set_nprobe(nprobe)
        self._keys, self._key_index = self._load_keys(keys_path)
        logging.info("ntotal: %d", self._index.ntotal())

    def parse_remote_path(self, save_path):
        """Split *save_path* into (remote_path, local_path).

        Local paths (or None) return (None, save_path); s3:// and blobs://
        paths return the remote URL plus a unique temp-file path to
        download into.
        """
        if save_path is None or (not save_path.startswith("s3://")
                                 and not save_path.startswith("blobs://")):
            return None, save_path
        remote_path = save_path
        filename = os.path.basename(remote_path)
        # Temp path made unique with the current epoch time.
        save_path = "%s/%d-%s" % (gettempdirb().decode("utf-8"), time(),
                                  filename)
        return remote_path, save_path

    def down_if_remote_path(self, save_path):
        """Download *save_path* when remote; return (remote_path, local_path).

        remote_path is None for local paths. Azure blobs under the prefix
        are concatenated into a single local file.
        """
        remote_path, local_path = self.parse_remote_path(save_path)
        if not remote_path:
            return None, local_path
        logging.debug("remote_path=%s", remote_path)
        if remote_path.startswith("s3://"):
            s3 = boto3.resource("s3")
            tokens = remote_path.replace("s3://", "").split("/")
            bucket_name = tokens[0]
            key = "/".join(tokens[1:])
            s3.Bucket(bucket_name).download_file(key, local_path)
        elif remote_path.startswith("blobs://"):
            blob_service = BlockBlobService(
                account_name=self._conf["azure_blobs"]["storage.account"],
                account_key=self._conf["azure_blobs"]["account.key"])
            container_name = self._conf["azure_blobs"]["container"]
            remote_path = remote_path.replace("blobs://", "")
            prefix = remote_path
            generator = blob_service.list_blobs(container_name, prefix=prefix)

            # `with` guarantees the local file is closed even if a blob
            # download raises (the original could leak the handle).
            with open(local_path, "ab") as fp:
                for blob in generator:
                    # Using `get_blob_to_bytes`
                    b = blob_service.get_blob_to_bytes(container_name,
                                                       blob.name)
                    fp.write(b.content)
                    # Or using `get_blob_to_stream`
                    # service.get_blob_to_stream(container_name, blob.name, fp)
                fp.flush()

        return remote_path, local_path

    def _load_keys(self, keys_path):
        """Read one key per CSV row; return (values, pandas Index) or (None, None)."""
        if not keys_path:
            return None, None
        _, keys_path = self.down_if_remote_path(keys_path)
        # NOTE(review): read_csv(squeeze=True) was removed in pandas 2.0;
        # migrate to .squeeze("columns") when the pandas pin allows it.
        keys = pd.read_csv(keys_path, header=None, squeeze=True, dtype=("str"))
        key_index = pd.Index(keys)
        logging.debug("keys: keys[size=%d]=%s, keys_index[size=%d]=%s",
                      len(keys), keys.values[:10], len(key_index),
                      key_index[:10])
        return keys.values, key_index

    def Total(self, request, context):
        """Return the number of vectors currently in the index."""
        return pb2.TotalResponse(count=self._index.ntotal())

    def Add(self, request, context):
        """Insert (or replace) one embedding, registering a new key if needed."""
        logging.debug("add - id: %d, %s", request.id, request.key)
        if request.key:
            if self._key_index is None or request.key not in self._key_index:
                # Unknown key: append it to both the index and the values
                # array; its id is its position in the key index.
                if self._key_index is None:
                    self._key_index = pd.Index([request.key])
                else:
                    self._key_index = self._key_index.append(
                        pd.Index([request.key]))

                request.id = self._key_index.get_loc(request.key)
                if self._keys is None:
                    self._keys = np.array([request.key])
                else:
                    self._keys = np.append(self._keys, [request.key])
            else:
                request.id = self._key_index.get_loc(request.key)

        xb = np.expand_dims(np.array(request.embedding, dtype=np.float32), 0)
        ids = np.array([request.id], dtype=np.int64)
        self._index.replace(xb, ids)

        return pb2.SimpleResponse(message="Added, %d!" % request.id)

    def Remove(self, request, context):
        """Remove the vector with request.id; report when it did not exist."""
        logging.debug("remove - id: %d", request.id)
        ids = np.array([request.id], dtype=np.int64)
        removed_count = self._index.remove(ids)

        if removed_count < 1:
            return pb2.SimpleResponse(message="Not existed, %s!" % request.id)
        return pb2.SimpleResponse(message="Removed, %s!" % request.id)

    def Search(self, request, context):
        """Nearest-neighbour search by id, or by key when request.key is set."""
        if request.key:
            if self._key_index is None or request.key not in self._key_index:
                logging.debug("search - Key not found: %s", request.key)
                return pb2.SearchResponse()
            request.id = self._key_index.get_loc(request.key)

        D, I = self._index.search_by_id(request.id, request.count)
        K = None
        if self._keys is not None:
            # Translate result ids back to string keys.
            K = self._keys[I[0]]
        return pb2.SearchResponse(ids=I[0], scores=D[0], keys=K)

    def SearchByEmbedding(self, request, context):
        """Nearest-neighbour search from a raw query embedding."""
        emb = np.array(request.embedding, dtype=np.float32)
        emb = np.expand_dims(emb, axis=0)  # faiss expects a (1, dim) batch
        D, I = self._index.search(emb, request.count)
        K = None
        if self._keys is not None:
            K = self._keys[I[0]]
        return pb2.SearchResponse(ids=I[0], scores=D[0], keys=K)

    def GetEmbedding(self, request, context):
        """Reconstruct and return the stored embedding for an id or key."""
        if request.key:
            if self._key_index is None or request.key not in self._key_index:
                logging.debug("getEmbedding - Key not found: %s", request.key)
                return pb2.EmbeddingResponse()
            request.id = self._key_index.get_loc(request.key)

        emb = self._index.reconstruct(request.id)
        if emb is not None:
            emb = emb.flatten()
        return pb2.EmbeddingResponse(embedding=emb)

    def Restore(self, request, context):
        """Reload the index from request.save_path (downloaded first if remote)."""
        logging.debug("restore - %s", request.save_path)
        remote_path, save_path = self.down_if_remote_path(request.save_path)
        self._remote_path = remote_path
        self._save_path = save_path
        # Bug fix: restore from the *local* path — request.save_path may be
        # a remote URL that the index cannot open directly.
        self._index.restore(save_path)
        return pb2.SimpleResponse(message="Restored, %s!" % request.save_path)

    def Reset(self, request, context):
        """Clear the index and forget all key mappings."""
        logging.debug("reset")
        self._index.reset()
        self._keys = None
        self._key_index = None
        return pb2.SimpleResponse(message="Reset!")

    def Import(self, request, context):
        """Rebuild the index from a TSV of embeddings, CSV of ids, and keys file."""
        logging.info("importing - %s, %s, %s", request.embs_path,
                     request.ids_path, request.keys_path)
        _, embs_path = self.down_if_remote_path(request.embs_path)
        _, ids_path = self.down_if_remote_path(request.ids_path)
        _, keys_path = self.down_if_remote_path(request.keys_path)
        df = pd.read_csv(embs_path, delimiter="\t", header=None)
        X = df.values
        df = pd.read_csv(ids_path, header=None)
        ids = df[0].values
        logging.info("ids[size=%d] = %s", len(ids), ids)

        # faiss requires contiguous float32 embeddings and int64 ids.
        X = np.ascontiguousarray(X, dtype=np.float32)
        ids = np.ascontiguousarray(ids, dtype=np.int64)

        self._index.rebuild(X, ids)

        self._keys, self._key_index = self._load_keys(keys_path)

        return pb2.SimpleResponse(
            message="Imported, %s, %s, %s!" %
            (request.embs_path, request.ids_path, request.keys_path))

    def save(self):
        """Persist the index to the current local save path."""
        logging.debug("saving index to %s", self._save_path)
        self._index.save(self._save_path)