Beispiel #1
0
    def Remove(self, request, context):
        """Drop one id from the index and delete its embedding file, if any.

        The id is always passed to ``faiss_index.remove_ids``; the on-disk file
        at ``self._get_filepath(request.id)`` is deleted only when it exists.

        Args:
            request: message carrying the integer ``id`` to remove.
            context: gRPC servicer context (not used here).

        Returns:
            pb2.SimpleReponse saying whether an embedding file existed.
        """
        target_id = request.id
        logging.debug('remove - id: %d', target_id)
        self.faiss_index.remove_ids(np.array([target_id], dtype=np.int64))

        emb_path = self._get_filepath(target_id)
        if not file_io.file_exists(emb_path):
            return pb2.SimpleReponse(message='Not existed, %s!' % target_id)

        file_io.delete_file(emb_path)
        return pb2.SimpleReponse(message='Removed, %s!' % target_id)
Beispiel #2
0
    def Add(self, request, context):
        """Insert or overwrite the embedding for a single request id.

        Short-circuits when ``_more_recent_emb_file_exists`` reports the
        request is already covered, or when ``fetch_embedding`` returns None.
        Otherwise the vector is given a leading batch axis of size 1 and
        written with ``faiss_index.replace``.

        Args:
            request: message with ``id`` (and ``url``, used in the
                no-embedding message).
            context: gRPC servicer context (not used here).

        Returns:
            pb2.SimpleReponse describing the outcome.
        """
        logging.debug('add - id: %d', request.id)
        if self._more_recent_emb_file_exists(request):
            return pb2.SimpleReponse(message='Already added, %s!' % request.id)

        vector = self.fetch_embedding(request)
        if vector is None:
            no_emb_msg = 'No embedding, id: %d, url: %s' % (request.id, request.url)
            return pb2.SimpleReponse(message=no_emb_msg)

        # replace() is fed a batch; presumably rows are vectors — wrap as one row.
        batch = np.expand_dims(vector, axis=0)
        id_batch = np.array([request.id], dtype=np.int64)
        self.faiss_index.replace(batch, id_batch)

        return pb2.SimpleReponse(message='Added, %s!' % request.id)
Beispiel #3
0
 def TrainCluster(self, request, context):
     """Train a k-means index and write it to ``request.save_filepath``.

     The result of ``self._train_kmeans(request.ncentroids)`` is stored on
     ``self._kmeans_index`` and then serialized with ``faiss.write_index``.
     Elapsed wall time is logged.

     Returns:
         pb2.SimpleReponse with message ``'clustered'``.
     """
     started_at = time.time()
     logging.info("kmeans_index training..")
     centroid_count = request.ncentroids
     self._kmeans_index = self._train_kmeans(centroid_count)
     faiss.write_index(self._kmeans_index, request.save_filepath)
     elapsed = time.time() - started_at
     logging.info("kmeans_index loaded %.2f s", elapsed)
     return pb2.SimpleReponse(message='clustered')
Beispiel #4
0
    def Import(self, request, context):
        """Bulk-import ``<id>.emb`` files found directly under ``request.path``.

        The file list is processed in batches produced by
        ``chunks(all_filepaths, 10000)``. A file is imported when either of
        these holds:

        * no stored embedding exists yet at ``self._get_filepath(id)``;
        * the incoming file's mtime is newer than the stored one.

        Imported vectors are loaded with ``_path_to_xb`` and written via
        ``faiss_index.replace``. Each imported file is then renamed
        (overwriting) to ``self._get_filepath(id, mkdir=True)``.

        Args:
            request: message whose ``path`` is the directory to scan
                (non-recursive).
            context: gRPC servicer context (not used here).

        Returns:
            pb2.SimpleReponse. The count in the ``'Imported, %d!'`` message is
            the number of ``.emb`` files found, including ones skipped because
            they were not newer.
        """
        import os

        def get_mtime(filepath):
            """Return the file's ``mtime_nsec``, or None if it does not exist."""
            if file_io.file_exists(filepath):
                return file_io.stat(filepath).mtime_nsec
            return None

        def is_new_emb(emb_id, filepath):
            """Return True if ``filepath`` should replace the stored embedding."""
            origin_mtime = get_mtime(self._get_filepath(emb_id))
            if origin_mtime is None:
                return True
            new_mtime = get_mtime(filepath)
            # The candidate may have vanished since globbing; comparing an int
            # with None raises TypeError on Python 3, so treat it as not new.
            if new_mtime is None:
                return False
            return origin_mtime < new_mtime

        def path_to_id(filepath):
            """Parse the integer id out of an ``<id>.emb`` file name."""
            # Parse the basename instead of slicing at len(request.path) + 1.
            # glob normalises a trailing '/' in request.path, and the old
            # fixed offset then cut off the first digit of the id.
            return int(os.path.basename(filepath)[:-len('.emb')])

        logging.info("Importing..")
        all_filepaths = list(glob.iglob('%s/*.emb' % request.path))

        total_count = len(all_filepaths)
        if total_count <= 0:
            logging.info("No files for importing!")
            return pb2.SimpleReponse(message='No files for importing!')

        logging.info("Importing files count: %d", total_count)

        for filepaths in chunks(all_filepaths, 10000):
            t0 = time.time()

            candidates = [(path_to_id(fp), fp) for fp in filepaths]
            ids_filepaths = [(emb_id, fp) for emb_id, fp in candidates if is_new_emb(emb_id, fp)]

            if not ids_filepaths:
                # Everything in this batch is already up to date; don't hand an
                # empty batch to _path_to_xb / replace.
                logging.debug("no new embeddings in chunk, skipped")
                continue

            xb = self._path_to_xb([fp for _, fp in ids_filepaths])
            ids = np.array([emb_id for emb_id, _ in ids_filepaths], dtype=np.int64)
            self.faiss_index.replace(xb, ids)

            # Move each imported file to its stored location, replacing any
            # previous file there.
            for emb_id, fp in ids_filepaths:
                file_io.rename(fp, self._get_filepath(emb_id, mkdir=True), overwrite=True)

            logging.info("%d embeddings added %.3f s", xb.shape[0], time.time() - t0)
        return pb2.SimpleReponse(message='Imported, %d!' % total_count)
Beispiel #5
0
    def Fetch(self, request, context):
        """Run ``self.fetch_embedding`` on every item in ``request.items``.

        The calls run on a 12-worker ``ThreadPool`` created for this request.
        An item counts as fetched when ``fetch_embedding`` returns a value
        other than None. If any call raises, ``result.get()`` re-raises it.

        Args:
            request: message with a repeated ``items`` field.
            context: gRPC servicer context (not used here).

        Returns:
            pb2.SimpleReponse reporting ``fetched of total`` counts.
        """
        total_count = len(request.items)

        pool = ThreadPool(12)
        try:
            results = [pool.spawn(self.fetch_embedding, item) for item in request.items]
            # Wait for exactly this request's tasks. gevent.wait() with no
            # arguments waits on the hub's whole event loop instead.
            gevent.wait(results)
            fetched_count = sum(1 for result in results if result.get() is not None)
        finally:
            # A new pool is built on every call. Shut it down so its worker
            # threads don't pile up across requests.
            pool.kill()

        return pb2.SimpleReponse(message='Fetched, %d of %d!' % (fetched_count, total_count))
Beispiel #6
0
 def Train(self, request, context):
     """Swap in the index returned by ``_new_trained_index`` and reset the old one.

     The replacement is assigned to ``self.faiss_index`` before the previous
     index has ``reset()`` called on it.

     Returns:
         pb2.SimpleReponse with message ``'Trained'``.
     """
     old_index, self.faiss_index = self.faiss_index, self._new_trained_index()
     old_index.reset()
     reply = pb2.SimpleReponse(message='Trained')
     return reply
Beispiel #7
0
 def Migrate(self, request, context):
     """Run ``self._sync()`` and log how long it took.

     Returns:
         pb2.SimpleReponse with message ``'Migrated.'``.
     """
     logging.info('Migrating...')
     start = time.time()
     self._sync()
     elapsed = time.time() - start
     logging.info("Migrated %.2f s", elapsed)
     reply = pb2.SimpleReponse(message='Migrated.')
     return reply
Beispiel #8
0
 def Info(self, request, context):
     """Return the value of ``self.faiss_index.ntotal()`` formatted as a string.

     Note: ``ntotal`` is called as a method, so ``faiss_index`` is presumably
     a project wrapper rather than a raw faiss Index — confirm.
     """
     total = self.faiss_index.ntotal()
     return pb2.SimpleReponse(message='%s' % total)
Beispiel #9
0
 def Save(self, request, context):
     """Call ``self.save()`` and acknowledge with message ``'Saved'``."""
     self.save()
     reply = pb2.SimpleReponse(message='Saved')
     return reply