def test_init_rp_tree():
    """Check that threaded.init_rp_tree produces the same heap as pynndescent_.init_rp_tree."""
    # NOTE(review): a second ``def test_init_rp_tree`` follows this one in the
    # same module. The later definition rebinds the name, so this version is
    # never collected by pytest. It also uses a different call signature
    # (``init_current_graph`` and ``dist_args``), so it looks like a stale copy
    # of the older API. Confirm, then delete it, or rename it if that API still
    # exists.
    # Use more data than the other tests since otherwise init_rp_tree has nothing to do
    np.random.seed(42)
    N = 100
    D = 128
    chunk_size = N // 8
    n_neighbors = 25
    data = np.random.rand(N, D).astype(np.float32)

    # Non-threaded reference run. The same rng_state object is passed to
    # init_current_graph and then to make_forest; keep this order in case those
    # calls advance it in place (TODO confirm).
    rng_state = new_rng_state()
    current_graph = pynndescent_.init_current_graph(
        data, dist, dist_args, n_neighbors, rng_state=rng_state, seed_per_row=True
    )
    _rp_forest = make_forest(data, n_neighbors, n_trees=8, rng_state=rng_state)
    leaf_array = rptree_leaf_array(_rp_forest)
    pynndescent_.init_rp_tree(data, dist, dist_args, current_graph, leaf_array)

    # Threaded run, rebuilt from a fresh new_rng_state(). This presumably gives
    # the same inputs as the reference run -- verify that new_rng_state (defined
    # elsewhere) is deterministic.
    rng_state = new_rng_state()
    current_graph_threaded = pynndescent_.init_current_graph(
        data, dist, dist_args, n_neighbors, rng_state=rng_state, seed_per_row=True
    )
    _rp_forest = make_forest(data, n_neighbors, n_trees=8, rng_state=rng_state)
    leaf_array = rptree_leaf_array(_rp_forest)
    # Two worker threads, with work split into chunks of chunk_size rows.
    parallel = joblib.Parallel(n_jobs=2, prefer="threads")
    threaded.init_rp_tree(
        data, dist, dist_args, current_graph_threaded, leaf_array, chunk_size, parallel
    )

    assert_allclose(current_graph_threaded, current_graph)
def test_init_rp_tree():
    """Check that threaded.init_rp_tree agrees with pynndescent_.init_rp_tree.

    Both implementations run on the same random data. Each starts from a heap
    created by ``utils.make_heap`` and from RP-forest leaves built with
    ``check_random_state(42)`` and a fresh ``new_rng_state()``. The two
    resulting heaps must be numerically equal.
    """
    # Use more data than the other tests since otherwise init_rp_tree has
    # nothing to do.
    np.random.seed(42)
    N = 100
    D = 128
    chunk_size = N // 8
    n_neighbors = 25
    data = np.random.rand(N, D).astype(np.float32)

    def _fresh_graph_and_leaves():
        # Build one (heap, RP-tree leaf array) pair for a single run.
        # Both runs use this helper so they cannot drift apart. The call order
        # matches the original inline setup, in case any call touches
        # global RNG state.
        rng_state = new_rng_state()
        random_state = check_random_state(42)
        graph = utils.make_heap(data.shape[0], n_neighbors)
        rp_forest = make_forest(
            data,
            n_neighbors,
            n_trees=8,
            leaf_size=None,
            rng_state=rng_state,
            random_state=random_state,
        )
        return graph, rptree_leaf_array(rp_forest)

    # Reference (non-threaded) run.
    current_graph, leaf_array = _fresh_graph_and_leaves()
    pynndescent_.init_rp_tree(data, dist, current_graph, leaf_array)

    # Threaded run: two worker threads, with work split into chunks of
    # chunk_size rows.
    current_graph_threaded, leaf_array = _fresh_graph_and_leaves()
    parallel = joblib.Parallel(n_jobs=2, prefer="threads")
    threaded.init_rp_tree(
        data, dist, current_graph_threaded, leaf_array, chunk_size, parallel
    )

    assert_allclose(current_graph_threaded, current_graph)