Esempio n. 1
0
def _faiss_knn(keys: torch.Tensor, queries: torch.Tensor, num_neighbors: int,
               distance: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """ k nearest neighbor search via faiss' brute-force `bfKnn` (legacy positional call form).

    Adapted from
    https://github.com/facebookresearch/XLM/blob/master/src/model/memory/utils.py

    :param keys: tensor of (num_keys, dim)
    :param queries: tensor of (num_queries, dim)
    :param num_neighbors: `k`
    :param distance: `dot_product` (faiss inner product) or `l2`
    :return: scores (float32) and indices (int64), both of (num_queries, num_neighbors)
    :raises RuntimeError: if faiss is not available
    """
    if not is_faiss_available():
        raise RuntimeError("faiss_knn requires faiss-gpu")
    import faiss

    assert distance in ['dot_product', 'l2']
    assert keys.size(1) == queries.size(1)

    metric = faiss.METRIC_INNER_PRODUCT if distance == 'dot_product' else faiss.METRIC_L2

    # bfKnn is told below that both inputs are row-major (the two `True` flags), so the raw
    # pointers must address densely packed rows. `.contiguous()` returns the tensor itself
    # when that already holds; otherwise the rebinding keeps the packed copy alive until
    # bfKnn has returned.
    keys = keys.contiguous()
    queries = queries.contiguous()

    k_ptr = _tensor_to_ptr(keys)
    q_ptr = _tensor_to_ptr(queries)

    # Output buffers; faiss writes into them through the raw pointers below.
    scores = keys.new_zeros((queries.size(0), num_neighbors),
                            dtype=torch.float32)
    indices = keys.new_zeros((queries.size(0), num_neighbors),
                             dtype=torch.int64)

    s_ptr = _tensor_to_ptr(scores)
    i_ptr = _tensor_to_ptr(indices)

    faiss.bfKnn(FAISS_RES, metric, k_ptr, True, keys.size(0), q_ptr, True,
                queries.size(0), queries.size(1), num_neighbors, s_ptr, i_ptr)
    return scores, indices
Esempio n. 2
0
    def __init__(self,
                 emb_dim: int,
                 dict_size: int,
                 momentum: float = 0.99,
                 epsilon: float = 1e-5,
                 knn_backend="faiss" if is_faiss_available() else "torch",
                 metric: str = 'l2'):
        """ Set up a vector-quantization codebook of `dict_size` vectors, each of size `emb_dim`.

        :param emb_dim: size of every codebook vector
        :param dict_size: number of codebook vectors
        :param momentum: stored as `self.gamma`; must satisfy 0 <= momentum <= 1
        :param epsilon: stored as `self.epsilon`
        :param knn_backend: k-NN backend name, stored as `self._knn_backend`. Its default is
            fixed once, when the class body runs: `"faiss"` if `is_faiss_available()` returned
            True at that moment, otherwise `"torch"`.
        :param metric: distance name, stored as `self.metric`
        """
        super(VQModule, self).__init__()

        # plain (non-buffer) configuration attributes
        self.emb_dim, self.dict_size, self.epsilon = emb_dim, dict_size, epsilon
        self._knn_backend, self.metric = knn_backend, metric
        # NOTE(review): a comment next to this flag mentioned "the issue with DataParallel";
        # which statement it refers to is unclear — confirm.
        self.frozen = False

        # the momentum has to lie in the closed interval [0, 1]
        assert 0 <= momentum <= 1
        self.gamma = momentum

        # initial codebook of shape (dict_size, emb_dim), every row scaled to unit L2 norm
        initial_codes = F.normalize(torch.randn(dict_size, emb_dim), dim=1, p=2)
        # buffers are registered in a fixed order: track_num, track_enc, embed
        self.register_buffer("track_num", torch.zeros(dict_size, 1))  # one slot per codebook entry
        self.register_buffer("track_enc", initial_codes.clone())  # independent copy of the codebook
        self.register_buffer("embed", initial_codes)
        self._first_time = True

        # defined elsewhere in the class; presumably prepares distributed buffer syncing — confirm
        self._distributed_update()
Esempio n. 3
0
def faiss_knn(keys: torch.Tensor, queries: torch.Tensor, num_neighbors: int,
              distance: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """ k nearest neighbor using faiss. Users are recommended to use `k_nearest_neighbor` instead.

    :param keys: tensor of (num_keys, dim)
    :param queries: tensor of (num_queries, dim)
    :param num_neighbors: `k`
    :param distance: user can use str (`inner_product`, `l2`, `l1`, `linf`, `jensen_shannon`;
        the older spelling `jansen_shannon` is still accepted) or faiss.METRIC_*, which is
        passed to faiss unchanged.
    :return: scores (float32) and indices (int64), both of (num_queries, num_neighbors)
    :raises RuntimeError: if faiss is not available
    :raises KeyError: if `distance` is a str naming no supported metric
    :raises ValueError: if `keys` and `queries` differ in their second dimension
    """

    if not is_faiss_available():
        raise RuntimeError("faiss_knn requires faiss-gpu")

    metric_map = {
        "inner_product": faiss.METRIC_INNER_PRODUCT,
        "l2": faiss.METRIC_L2,
        "l1": faiss.METRIC_L1,
        "linf": faiss.METRIC_Linf,
        "jensen_shannon": faiss.METRIC_JensenShannon,
        # historical misspelling, kept so existing callers keep working
        "jansen_shannon": faiss.METRIC_JensenShannon,
    }

    if isinstance(distance, str):
        if distance not in metric_map:
            raise KeyError("unknown distance {!r}; expected one of {}".format(
                distance, sorted(metric_map)))
        metric = metric_map[distance]
    else:
        # already a faiss.METRIC_* value
        metric = distance

    # `args.dims` below is taken from `queries` but is also used to walk the rows of `keys`,
    # so a dimension mismatch would make faiss misread the keys buffer.
    if keys.size(1) != queries.size(1):
        raise ValueError("keys and queries must share the feature dimension, got {} and {}".format(
            keys.size(1), queries.size(1)))

    # Both inputs are declared row-major below, so the raw pointers must address densely packed
    # rows. `.contiguous()` is a no-op when that already holds; otherwise the rebinding keeps the
    # packed copy alive until bfKnn has returned.
    keys = keys.contiguous()
    queries = queries.contiguous()

    k_ptr = _tensor_to_ptr(keys)
    q_ptr = _tensor_to_ptr(queries)

    # output buffers written by faiss through the raw pointers
    scores = keys.new_empty((queries.size(0), num_neighbors),
                            dtype=torch.float32)
    indices = keys.new_empty((queries.size(0), num_neighbors),
                             dtype=torch.int64)

    s_ptr = _tensor_to_ptr(scores)
    i_ptr = _tensor_to_ptr(indices)

    args = faiss.GpuDistanceParams()
    args.metric = metric
    args.k = num_neighbors
    args.dims = queries.size(1)
    args.vectors = k_ptr
    args.vectorsRowMajor = True
    args.numVectors = keys.size(0)
    args.queries = q_ptr
    args.queriesRowMajor = True
    args.numQueries = queries.size(0)
    args.outDistances = s_ptr
    args.outIndices = i_ptr
    faiss.bfKnn(FAISS_RES, args)
    return scores, indices
Esempio n. 4
0
def k_nearest_neighbor(
        keys: torch.Tensor,
        queries: torch.Tensor,
        num_neighbors: int,
        distance: str,
        *,
        backend: str = "torch") -> Tuple[torch.Tensor, torch.Tensor]:
    """ k-Nearest Neighbor search

    `_faiss_knn` is used only when `backend == "faiss"` and faiss is available. Every other
    case, including a "faiss" request on a machine without faiss, falls back to `_torch_knn`.

    :param keys: tensor of (num_keys, dim)
    :param queries: tensor of (num_queries, dim)
    :param num_neighbors: `k`
    :param distance: registry_name of distance (`dot_product` or `l2`)
    :param backend: backend (`faiss` or `torch`)
    :return: scores, indices
    """
    assert backend in ("faiss", "torch")
    use_faiss = backend == "faiss" and is_faiss_available()
    if use_faiss:
        return _faiss_knn(keys, queries, num_neighbors, distance)
    return _torch_knn(keys, queries, num_neighbors, distance)
Esempio n. 5
0
def k_nearest_neighbor(
        keys: torch.Tensor,
        queries: torch.Tensor,
        num_neighbors: int,
        distance: str,
        *,
        backend: str = "torch") -> Tuple[torch.Tensor, torch.Tensor]:
    """ k-Nearest Neighbor search. Faiss backend requires GPU. torch backend is JITtable

    :param keys: tensor of (num_keys, dim)
    :param queries: tensor of (num_queries, dim)
    :param num_neighbors: `k`
    :param distance: name of distance (`inner_product` or `l2`). Faiss backend additionally supports `l1`, `linf`, `jansen_shannon`.
    :param backend: backend (`faiss` or `torch`). `torch_jit` is accepted as well and, like any
        value other than `faiss`, is served by `torch_knn`. A `faiss` request also falls back
        to `torch_knn` when faiss is not available.
    :return: scores, indices
    """

    assert backend in {"faiss", "torch", "torch_jit"}
    # both inputs must be 2-D and share the feature dimension
    # (the size(1) check comes first, so a 1-D input fails inside `size(1)`)
    assert keys.size(1) == queries.size(1)
    assert keys.ndim == 2 and queries.ndim == 2

    if backend == "faiss" and is_faiss_available():
        return faiss_knn(keys, queries, num_neighbors, distance)
    return torch_knn(keys, queries, num_neighbors, distance)
Esempio n. 6
0

def k_nearest_neighbor(
        keys: torch.Tensor,
        queries: torch.Tensor,
        num_neighbors: int,
        distance: str,
        *,
        backend: str = "torch") -> Tuple[torch.Tensor, torch.Tensor]:
    """ k-Nearest Neighbor search

    Runs `_torch_knn` unless the faiss backend is both requested and available, in which
    case `_faiss_knn` is used instead.

    :param keys: tensor of (num_keys, dim)
    :param queries: tensor of (num_queries, dim)
    :param num_neighbors: `k`
    :param distance: registry_name of distance (`dot_product` or `l2`)
    :param backend: backend (`faiss` or `torch`)
    :return: scores, indices
    """
    assert backend in ["faiss", "torch"]
    # Guard clause: any case other than "faiss requested and available" goes to torch.
    if backend != "faiss" or not is_faiss_available():
        return _torch_knn(keys, queries, num_neighbors, distance)
    return _faiss_knn(keys, queries, num_neighbors, distance)


if is_faiss_available():
    import faiss

    # Single GPU resource handle shared by the faiss k-NN helpers (passed to `faiss.bfKnn`).
    FAISS_RES = faiss.StandardGpuResources()
    # Order faiss work on the default (null) CUDA stream of every device.
    FAISS_RES.setDefaultNullStreamAllDevices()
    # Temporary (scratch) memory budget: 1200 MiB.
    FAISS_RES.setTempMemory(1200 * (1 << 20))