예제 #1
0
    def kneighbors(self, X, n_neighbors=None, return_distance=True):
        """Find the K nearest neighbors of each word in ``X``.

        Returns indices of, and optionally distances to, the neighbors of
        each query word. Neighbors are ranked by dot-product similarity
        between the word's vector and the rows of ``self.syn0_``.

        Parameters
        ----------
        X: numpy.array 1D
          Array of strings containing the words whose neighbors should
          be determined. An empty ``X`` yields empty result arrays.

        n_neighbors: int or None
          Number of neighbors to get. ``None`` (the default) means use
          ``self.topn``. The value is capped at
          ``self.syn0_.shape[0] - 1``.

        return_distance: boolean (default=True)
          If False, distances will not be returned.

        Returns
        -------
        neighbors: numpy.array
          Indices of the nearest points in the population matrix
          (``self.syn0_``), one row per word in ``X``, most similar first.

        distances: numpy.array (optional)
          ``(1 - similarity) / 2`` for each returned neighbor. Only
          present if `return_distance=True`. Values fall in [0, 1] only
          when similarities lie in [-1, 1] (e.g. unit-normalized vectors
          -- presumably the case here, TODO confirm).

        """
        # Compare against None explicitly so an explicit 0 is not silently
        # replaced by the default.
        if n_neighbors is None:
            n_neighbors = self.topn
        n_neighbors = min(n_neighbors, self.syn0_.shape[0] - 1)
        neighbors, similarities = [], []

        for x in X:
            xvec = self._get_vector_from_word(x)
            similarity = np.dot(self.syn0_, xvec.T)

            # fast_argsort yields indices in ascending key order, so on
            # -similarity the first entry is the most similar row. That row
            # is assumed to be the query word itself and is dropped, which
            # is why one extra index is requested.
            # pylint: disable=invalid-unary-operand-type
            neighbor_indices = fast_argsort(-similarity, n_neighbors + 1)[1:]

            neighbors.append(neighbor_indices)
            similarities.append(similarity[neighbor_indices])

        if neighbors:
            neighbors = np.vstack(neighbors)
            similarities = np.vstack(similarities)
        else:
            # np.vstack raises on an empty list; return well-formed empty
            # results so empty input is handled uniformly.
            neighbors = np.empty((0, n_neighbors), dtype=np.intp)
            similarities = np.empty((0, n_neighbors))

        if not return_distance:
            return neighbors

        # Map similarity s to distance (1 - s) / 2: s = 1 -> 0, s = -1 -> 1.
        distances = (1 - similarities) / 2.
        return neighbors, distances
예제 #2
0
 def test_compare_fast_argsort_with_argsort_only_n_smallest(
         self, vec, n, fast_argsort):
     slow = np.argsort(vec)[:n]
     fast = fast_argsort(vec, n)
     assert (slow == fast).all()
예제 #3
0
 def test_compare_fast_argsort_with_argsort_all(self, vec, fast_argsort):
     slow = np.argsort(vec)
     fast = fast_argsort(vec, len(vec))
     assert (slow == fast).all()