def test_gradient_bh_multithread_match_sequential():
    # Check that the Barnes-Hut KL divergence and gradient computed with
    # several threads match the values computed with a single thread.
    n_features = 10
    n_samples = 30
    n_components = 2
    degrees_of_freedom = 1

    angle = 3
    perplexity = 5

    random_state = check_random_state(0)
    data = random_state.randn(n_samples, n_features).astype(np.float32)
    params = random_state.randn(n_samples, n_components)

    # Use every other sample as a neighbor when building the distance graph.
    n_neighbors = n_samples - 1
    distances_csr = NearestNeighbors().fit(data).kneighbors_graph(
        n_neighbors=n_neighbors, mode='distance')
    P_bh = _joint_probabilities_nn(distances_csr, perplexity, verbose=0)

    # Single-threaded run used as the reference for both outputs.
    kl_sequential, grad_sequential = _kl_divergence_bh(
        params, P_bh, degrees_of_freedom, n_samples, n_components,
        angle=angle, skip_num_points=0, verbose=0,
        num_threads=1)

    for num_threads in [2, 4]:
        kl_multithread, grad_multithread = _kl_divergence_bh(
            params, P_bh, degrees_of_freedom, n_samples, n_components,
            angle=angle, skip_num_points=0, verbose=0,
            num_threads=num_threads)

        assert_allclose(kl_multithread, kl_sequential, rtol=1e-6)
        # Compare against the sequential gradient: the previous assertion
        # compared grad_multithread with itself and so could never fail.
        # The tolerance mirrors the one used for the KL divergence above.
        assert_allclose(grad_multithread, grad_sequential, rtol=1e-6)
def test_barnes_hut_angle():
    # With angle=0 the Barnes-Hut computation is expected to agree with the
    # exact method: the joint probabilities must match to 5 decimals and the
    # KL divergences to 3 decimals, for 2 and 3 output dimensions.
    angle = 0.0
    perplexity = 10
    n_samples = 100
    n_features = 5
    n_neighbors = n_samples - 1

    for n_components in (2, 3):
        degrees_of_freedom = float(n_components - 1.0)

        # Fresh generator per iteration so both dimensionalities start from
        # the same seed; data is drawn before the embedding parameters.
        rng = check_random_state(0)
        data = rng.randn(n_samples, n_features)
        pairwise = pairwise_distances(data)
        params = rng.randn(n_samples, n_components)

        # Exact method on the dense pairwise distances.
        P = _joint_probabilities(pairwise, perplexity, verbose=0)
        kl_exact, _ = _kl_divergence(
            params, P, degrees_of_freedom, n_samples, n_components)

        # Barnes-Hut method on the neighbors graph.
        neighbors = NearestNeighbors().fit(data)
        distances_csr = neighbors.kneighbors_graph(
            n_neighbors=n_neighbors, mode='distance')
        P_bh = _joint_probabilities_nn(distances_csr, perplexity, verbose=0)
        kl_bh, _ = _kl_divergence_bh(
            params, P_bh, degrees_of_freedom, n_samples, n_components,
            angle=angle, skip_num_points=0, verbose=0)

        # Bring both probability matrices to dense square form to compare.
        dense_exact = squareform(P)
        dense_bh = P_bh.toarray()
        assert_array_almost_equal(dense_bh, dense_exact, decimal=5)
        assert_almost_equal(kl_exact, kl_bh, decimal=3)