Esempio n. 1
0
    def testKNeighborsGraphExecution(self):
        """Compare mars ``kneighbors_graph`` with scikit-learn's result.

        For both ``'connectivity'`` and ``'distance'`` modes the graph is
        built once for an explicit query tensor and once with no query
        argument. Each graph must fetch as a ``SparseNDArray``, tile into
        more than one chunk, and match scikit-learn exactly. A tiny
        list-based fit and an invalid ``mode`` are checked afterwards.
        """
        rng = np.random.RandomState(0)
        train_data = rng.rand(10, 5)
        query_data = rng.rand(8, 5)

        train_tensor = mt.tensor(train_data, chunk_size=7)
        query_tensor = mt.tensor(query_data, chunk_size=(5, 3))

        model = NearestNeighbors(n_neighbors=3)
        model.fit(train_tensor)
        reference = SkNearestNeighbors(n_neighbors=3)
        reference.fit(train_data)

        def check_graph(mode, mars_query=(), sk_query=()):
            # Empty query tuples mean "call without a positional query".
            graph = model.kneighbors_graph(*mars_query, mode=mode)
            fetched = graph.fetch()

            self.assertIsInstance(fetched, SparseNDArray)
            # The chunk sizes chosen above should yield a multi-chunk graph.
            self.assertGreater(len(get_tiled(graph).chunks), 1)

            expected = reference.kneighbors_graph(*sk_query, mode=mode)
            np.testing.assert_array_equal(fetched.toarray(),
                                          expected.toarray())

        for mode in ('connectivity', 'distance'):
            check_graph(mode, (query_tensor,), (query_data,))
            check_graph(mode)

        points = [[0], [3], [1]]

        small_model = NearestNeighbors(n_neighbors=2)
        small_reference = SkNearestNeighbors(n_neighbors=2)
        small_model.fit(points)
        small_reference.fit(points)

        np.testing.assert_array_equal(
            small_model.kneighbors_graph(points).fetch().toarray(),
            small_reference.kneighbors_graph(points).toarray())

        # An unsupported mode must be rejected.
        with self.assertRaises(ValueError):
            small_model.kneighbors_graph(mode='unknown')
Esempio n. 2
0
 def topk_rbf(X, Y=None, n_neighbors=10, gamma=1e-5):
     """Build an RBF-weighted k-nearest-neighbour graph.

     Fits a Euclidean ``NearestNeighbors`` model on ``X``, takes the
     ``'distance'``-mode kNN graph for ``Y`` and maps each distance ``d``
     to the weight ``exp(-gamma * d ** 2)``.

     Args:
         X: data the neighbour model is fitted on.
         Y: query points. ``None`` queries the fitted data itself
             (presumably excluding each sample from its own neighbours, as
             scikit-learn's ``kneighbors_graph`` does; confirm for mars).
         n_neighbors: number of neighbours kept for each query point.
         gamma: RBF coefficient. Larger values make the weights decay faster
             with distance.

     Returns:
         The transposed weight tensor, which is asserted to be sparse.
     """
     # Forward n_neighbors. It was previously hard-coded to 10, which
     # silently ignored the caller's argument.
     nn = NearestNeighbors(n_neighbors=n_neighbors, metric='euclidean',
                           n_jobs=-1)
     nn.fit(X)
     distances = nn.kneighbors_graph(Y, mode='distance')
     W = mt.exp(-1 * mt.power(distances, 2) * gamma)
     # NOTE(review): exp(0) == 1, so this assertion relies on mt.exp keeping
     # the result sparse (i.e. acting only on stored neighbour entries).
     # Confirm against the mars sparse semantics.
     assert W.issparse()
     return W.T