def _faiss_knn(keys: torch.Tensor,
               queries: torch.Tensor,
               num_neighbors: int,
               distance: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """ k nearest neighbor search delegated to `faiss.bfKnn`.

    :param keys: tensor of (num_keys, dim)
    :param queries: tensor of (num_queries, dim)
    :param num_neighbors: `k`
    :param distance: `dot_product` or `l2`
    :return: scores (float32) and indices (int64), each of shape (num_queries, num_neighbors)
    """
    # Reference: https://github.com/facebookresearch/XLM/blob/master/src/model/memory/utils.py
    if not is_faiss_available():
        raise RuntimeError("faiss_knn requires faiss-gpu")
    import faiss

    assert distance in ('dot_product', 'l2')
    assert keys.size(1) == queries.size(1)

    if distance == 'dot_product':
        metric = faiss.METRIC_INNER_PRODUCT
    else:
        metric = faiss.METRIC_L2

    num_keys = keys.size(0)
    num_queries = queries.size(0)
    dim = queries.size(1)

    keys_ptr = _tensor_to_ptr(keys)
    queries_ptr = _tensor_to_ptr(queries)

    # Output buffers are created from `keys` (so they follow its device) and
    # handed to faiss as raw pointers.
    out_shape = (num_queries, num_neighbors)
    scores = keys.new_zeros(out_shape, dtype=torch.float32)
    indices = keys.new_zeros(out_shape, dtype=torch.int64)
    scores_ptr = _tensor_to_ptr(scores)
    indices_ptr = _tensor_to_ptr(indices)

    # NOTE(review): the two `True` flags are presumably the row-major markers
    # for keys and queries -- confirm against the faiss `bfKnn` signature.
    faiss.bfKnn(FAISS_RES, metric,
                keys_ptr, True, num_keys,
                queries_ptr, True, num_queries,
                dim, num_neighbors,
                scores_ptr, indices_ptr)
    return scores, indices
def __init__(self,
             emb_dim: int,
             dict_size: int,
             momentum: float = 0.99,
             epsilon: float = 1e-5,
             knn_backend="faiss" if is_faiss_available() else "torch",
             metric: str = 'l2'):
    """ Set up the hyper-parameters and the dictionary buffers.

    :param emb_dim: size of each dictionary entry
    :param dict_size: number of dictionary entries
    :param momentum: value in [0, 1], stored as `self.gamma`
    :param epsilon: stored as `self.epsilon`
    :param knn_backend: stored as `self._knn_backend`; the default is `faiss` when
        `is_faiss_available()` is true at definition time, otherwise `torch`
    :param metric: stored as `self.metric`
    """
    super(VQModule, self).__init__()

    self.emb_dim = emb_dim
    self.dict_size = dict_size
    self.epsilon = epsilon
    self._knn_backend = knn_backend
    self.metric = metric
    # this handles the issue with DataParallel
    self.frozen = False

    assert 0 <= momentum <= 1
    self.gamma = momentum

    # embed: DxC (emb_dim==C)
    # Rows are drawn from a standard normal and scaled to unit L2 norm.
    random_entries = torch.randn(dict_size, emb_dim)
    embed = F.normalize(random_entries, p=2, dim=1)

    # Registration order is kept: track_num, track_enc, embed.
    usage_counts = torch.zeros(dict_size, 1)
    self.register_buffer("track_num", usage_counts)
    self.register_buffer("track_enc", embed.clone())
    self.register_buffer("embed", embed)

    self._first_time = True
    # NOTE(review): `_distributed_update` is defined outside this block; its
    # effect cannot be determined from here.
    self._distributed_update()
def faiss_knn(keys: torch.Tensor,
              queries: torch.Tensor,
              num_neighbors: int,
              distance: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """ k nearest neighbor using faiss. Users are recommended to use `k_nearest_neighbor` instead.

    :param keys: tensor of (num_keys, dim)
    :param queries: tensor of (num_queries, dim)
    :param num_neighbors: `k`
    :param distance: user can use str (`inner_product`, `l2`, `l1`, `linf`, `jansen_shannon`)
        or faiss.METRIC_*.
    :return: scores, indices in tensor; scores are float32 and indices are int64,
        each of shape (num_queries, num_neighbors)
    :raises RuntimeError: if `is_faiss_available()` is false
    :raises KeyError: if `distance` is a str that is not a supported metric name
    """
    if not is_faiss_available():
        raise RuntimeError("faiss_knn requires faiss-gpu")

    # `faiss` is expected to be in module scope whenever `is_faiss_available()` holds.
    metric_map = {
        "inner_product": faiss.METRIC_INNER_PRODUCT,
        "l2": faiss.METRIC_L2,
        "l1": faiss.METRIC_L1,
        "linf": faiss.METRIC_Linf,
        "jansen_shannon": faiss.METRIC_JensenShannon
    }

    # Resolve the metric before touching any tensor so that an unsupported
    # name fails fast with a message listing the accepted names.
    if isinstance(distance, str):
        if distance not in metric_map:
            raise KeyError("unsupported distance {!r}; expected one of {}".format(
                distance, sorted(metric_map)))
        metric = metric_map[distance]
    else:
        # Anything that is not a str is forwarded to faiss unchanged.
        metric = distance

    k_ptr = _tensor_to_ptr(keys)
    q_ptr = _tensor_to_ptr(queries)

    # Uninitialised output buffers created from `keys`; their pointers are
    # handed to faiss as `outDistances` / `outIndices`.
    out_shape = (queries.size(0), num_neighbors)
    scores = keys.new_empty(out_shape, dtype=torch.float32)
    indices = keys.new_empty(out_shape, dtype=torch.int64)
    s_ptr = _tensor_to_ptr(scores)
    i_ptr = _tensor_to_ptr(indices)

    args = faiss.GpuDistanceParams()
    args.metric = metric
    args.k = num_neighbors
    args.dims = queries.size(1)
    args.vectors = k_ptr
    args.vectorsRowMajor = True
    args.numVectors = keys.size(0)
    args.queries = q_ptr
    args.queriesRowMajor = True
    args.numQueries = queries.size(0)
    args.outDistances = s_ptr
    args.outIndices = i_ptr
    faiss.bfKnn(FAISS_RES, args)
    return scores, indices
def k_nearest_neighbor(keys: torch.Tensor,
                       queries: torch.Tensor,
                       num_neighbors: int,
                       distance: str,
                       *,
                       backend: str = "torch") -> Tuple[torch.Tensor, torch.Tensor]:
    """ k-Nearest Neighbor search dispatched to the faiss or the torch implementation.

    :param keys: tensor of (num_keys, dim)
    :param queries: tensor of (num_queries, dim)
    :param num_neighbors: `k`
    :param distance: registry_name of distance (`dot_product` or `l2`)
    :param backend: backend (`faiss` or `torch`); `faiss` is used only when
        `is_faiss_available()` is true, otherwise the torch implementation runs
    :return: scores, indices
    """
    assert backend in ["faiss", "torch"]

    faiss_requested = backend == "faiss"
    if faiss_requested and is_faiss_available():
        return _faiss_knn(keys, queries, num_neighbors, distance)
    # Requested torch, or faiss requested but not available.
    return _torch_knn(keys, queries, num_neighbors, distance)
def k_nearest_neighbor(keys: torch.Tensor,
                       queries: torch.Tensor,
                       num_neighbors: int,
                       distance: str,
                       *,
                       backend: str = "torch") -> Tuple[torch.Tensor, torch.Tensor]:
    """ k-Nearest Neighbor search. Faiss backend requires GPU. torch backend is JITtable

    :param keys: tensor of (num_keys, dim)
    :param queries: tensor of (num_queries, dim)
    :param num_neighbors: `k`
    :param distance: name of distance (`inner_product` or `l2`). Faiss backend additionally supports `l1`,
        `linf`, `jansen_shannon`.
    :param backend: backend (`faiss` or `torch`). `torch_jit` is also accepted and is
        handled by the torch implementation. `faiss` falls back to the torch
        implementation when `is_faiss_available()` is false.
    :return: scores, indices
    """
    assert backend in {"faiss", "torch", "torch_jit"}, "unknown backend"
    # The rank check has to come first: reading `size(1)` of a 1-D tensor raises
    # IndexError inside torch and would hide the intended assertion.
    assert keys.ndim == 2 and queries.ndim == 2, "keys and queries must be 2-D"
    assert keys.size(1) == queries.size(1), "keys and queries must share `dim`"

    f = faiss_knn if backend == "faiss" and is_faiss_available() else torch_knn
    return f(keys, queries, num_neighbors, distance)
def k_nearest_neighbor(keys: torch.Tensor,
                       queries: torch.Tensor,
                       num_neighbors: int,
                       distance: str,
                       *,
                       backend: str = "torch") -> Tuple[torch.Tensor, torch.Tensor]:
    """ k-Nearest Neighbor search; picks the faiss or the torch implementation.

    :param keys: tensor of (num_keys, dim)
    :param queries: tensor of (num_queries, dim)
    :param num_neighbors: `k`
    :param distance: registry_name of distance (`dot_product` or `l2`)
    :param backend: backend (`faiss` or `torch`); the torch implementation is used
        whenever `faiss` is requested but `is_faiss_available()` is false
    :return: scores, indices
    """
    assert backend in ["faiss", "torch"]

    if backend == "faiss" and is_faiss_available():
        knn_impl = _faiss_knn
    else:
        knn_impl = _torch_knn
    return knn_impl(keys, queries, num_neighbors, distance)


# Module-level faiss GPU resources, created once at import time and only when
# faiss is available.
if is_faiss_available():
    import faiss

    FAISS_RES = faiss.StandardGpuResources()
    # NOTE(review): presumably makes faiss issue its work on the default (null)
    # CUDA stream of every device -- confirm against the faiss documentation.
    FAISS_RES.setDefaultNullStreamAllDevices()
    # 1200 * 1024 * 1024 bytes of temporary memory.
    FAISS_RES.setTempMemory(1200 * 1024 * 1024)