def _check_p_distance_vs_KDT(self, p):
    """Compare BallTree Minkowski-``p`` distances with scipy's cKDTree.

    Both trees index ``self.X`` with a leaf size of 10 and are asked for the
    5 nearest neighbours of every point in ``self.X``. Only the returned
    distances are compared; indices are not checked.
    """
    n_neighbors = 5
    ball = BallTree(self.X, leaf_size=10, metric='minkowski', p=p)
    reference = cKDTree(self.X, leafsize=10)
    # element 0 of each query result is the distance array
    ball_dist = ball.query(self.X, k=n_neighbors)[0]
    ref_dist = reference.query(self.X, k=n_neighbors, p=p)[0]
    assert_array_almost_equal(ball_dist, ref_dist)
def _check_metrics_float(self, k, metric, kwargs):
    """Check BallTree k-NN distances on float data against brute force.

    The tree built with ``metric``/``kwargs`` on ``self.X`` is queried with
    ``self.X`` itself. The reference distances are taken from the full
    pairwise matrix produced by ``DistanceMetric.pdist``.
    """
    tree = BallTree(self.X, metric=metric, **kwargs)
    tree_dist = tree.query(self.X, k=k)[0]

    # brute-force reference: square pairwise-distance matrix
    pairwise = DistanceMetric(metric=metric, **kwargs).pdist(self.X, squareform=True)
    nearest = np.argsort(pairwise, 1)[:, :k]
    rows = np.arange(self.X.shape[0])[:, None]
    brute_dist = pairwise[rows, nearest]

    # Indices are deliberately not compared: ties for nearest neighbour
    # can make them differ between methods, while the distances still
    # show whether the search succeeded.
    assert_array_almost_equal(tree_dist, brute_dist)
def _check_metrics_bool(self, k, metric, kwargs):
    """Check BallTree k-NN distances on boolean data against brute force.

    The tree indexes ``self.Xbool`` and is queried with ``self.Ybool``. The
    reference distances come from ``DistanceMetric.cdist(Ybool, Xbool)``.
    """
    tree = BallTree(self.Xbool, metric=metric, **kwargs)
    tree_dist = tree.query(self.Ybool, k=k)[0]

    # brute-force reference: query-by-training distance matrix
    cross = DistanceMetric(metric=metric, **kwargs).cdist(self.Ybool, self.Xbool)
    nearest = np.argsort(cross, 1)[:, :k]
    rows = np.arange(self.Ybool.shape[0])[:, None]
    brute_dist = cross[rows, nearest]

    # Indices are not compared: boolean data very often produces ties for
    # nearest neighbour, but the distances are correct either way.
    assert_array_almost_equal(tree_dist, brute_dist)
def test_query_radius_count(self):
    """query_radius(count_only=True) must match a brute-force neighbour count.

    The radius is the mean of all pairwise distances. The expected count per
    point is the number of entries in its row of the distance matrix that
    are <= that radius.
    """
    # center the data (presumably maps self.X from [0, 1] to [-1, 1])
    centred = 2 * self.X - 1
    pairwise = DistanceMetric().pdist(centred, squareform=True)
    radius = np.mean(pairwise)

    tree = BallTree(centred)
    expected = (pairwise <= radius).sum(1)
    observed = tree.query_radius(centred, radius, count_only=True)
    assert_array_almost_equal(observed, expected)
def test_query_radius_indices(self, n_queries=20):
    """query_radius(return_distance=False) must return the brute-force neighbours.

    The first ``n_queries`` points are used as queries. For each one, the
    tree's neighbour indices are sorted and then concatenated. The result
    must equal the column indices of every entry of the brute-force distance
    matrix that is <= ``r``, taken in row-major order.

    Parameters
    ----------
    n_queries : int
        Number of leading points of the centred data used as queries.
    """
    # center the data
    X = 2 * self.X - 1
    queries = X[:n_queries]
    D = DistanceMetric().cdist(queries, X)
    r = np.mean(D)

    bt = BallTree(X)
    ind = bt.query_radius(queries, r, return_distance=False)
    # np.concatenate requires a sequence. Under Python 3 ``map`` returns a
    # lazy iterator, which numpy rejects, so build a real list here.
    ind = np.concatenate([np.sort(row) for row in ind])

    # column index of every entry of D, broadcast across the rows
    ind2 = np.zeros(D.shape) + np.arange(D.shape[1])
    ind2 = ind2[D <= r]
    assert_array_almost_equal(ind, ind2)
def test_query_radius_distance(self):
    """Yield radius-query distance checks for 10 radii spanning the data.

    The radii are evenly spaced between the smallest and largest true
    distance from a query point near the origin. Each case is delegated to
    ``self._check_query_radius_distance``.
    """
    # center the data
    centred = 2 * self.X - 1
    # a query point close to the origin (a scaled-down copy of the first row)
    origin_query = 0.01 * centred[:1]
    tol = 1E-15  # roundoff error can cause test to fail
    tree = BallTree(centred, leaf_size=5)

    # sorted reference distances from the query point to every sample
    ref_dist = DistanceMetric().cdist(origin_query, centred)[0]
    ref_dist.sort()

    radii = np.linspace(ref_dist[0], ref_dist[-1], 10)
    for radius in radii:
        yield (self._check_query_radius_distance,
               centred, tree, origin_query, ref_dist, radius, tol)
def test_pickle(self):
    """Yield pickle round-trip checks for protocols 0, 1 and 2.

    A single tree (leaf_size=1) and its query result on ``self.X`` are built
    once. They are then handed to ``self._check_pickle`` once per protocol.
    """
    tree = BallTree(self.X, leaf_size=1)
    # query() returns (distances, indices); both are forwarded positionally
    # to _check_pickle in exactly that order.
    dist_ref, ind_ref = tree.query(self.X)
    for proto in range(3):
        yield (self._check_pickle, proto, tree, dist_ref, ind_ref)
def test_query_knn(self):
    """Yield k-NN comparisons between BallTree and cKDTree.

    Every combination of k in (1, 2, 4, 8, 16) and dual-tree mode on/off is
    delegated to ``self._check_query_knn``. The order is k-major, with
    dual-tree=True first for each k.
    """
    tree = BallTree(self.X)
    reference = cKDTree(self.X)
    cases = [(k, dual) for k in (1, 2, 4, 8, 16) for dual in (True, False)]
    for k, dual in cases:
        yield (self._check_query_knn, tree, reference, k, dual)