def test_update_w_prepare_query_accuracy(nn_data, metric):
    """Check query accuracy after the sequence ``prepare`` -> ``update`` -> ``prepare``.

    The index is built on ``nn_data[200:800]``, prepared, grown with the
    fresh points ``nn_data[800:]`` and prepared again. The rows
    ``nn_data[:200]`` are then queried. The results are compared with exact
    neighbours computed over ``nn_data[200:]``, which is the union of the
    original and the fresh points. At least 95% of the true k-nearest
    neighbours must be recovered.
    """
    k = 10
    nnd = NNDescent(
        nn_data[200:800],
        metric=metric,
        n_neighbors=k,
        # NOTE(review): unseeded, so the recall may vary between runs.
        random_state=None,
        compressed=False,
    )
    nnd.prepare()
    nnd.update(xs_fresh=nn_data[800:])
    nnd.prepare()
    knn_indices, _ = nnd.query(nn_data[:200], k=k, epsilon=0.2)

    # Exact neighbours over every point the index now holds.
    true_nnd = NearestNeighbors(metric=metric).fit(nn_data[200:])
    true_indices = true_nnd.kneighbors(nn_data[:200], k, return_distance=False)

    # Count, per query row, how many true neighbours the approximate search
    # found. np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
    num_correct = sum(
        np.count_nonzero(np.isin(true_row, approx_row))
        for true_row, approx_row in zip(true_indices, knn_indices)
    )
    recall = num_correct / (true_indices.shape[0] * k)
    assert recall >= 0.95, (
        "NN-descent query did not get 95% " "accuracy on nearest neighbors"
    )
def test_tree_numbers_after_multiple_updates(n_trees):
    """Check that ``n_trees`` drops to ``n_trees_after_update`` once the index is updated.

    A freshly built index must report the requested ``n_trees`` and an
    ``n_trees_after_update`` of ``max(1, round(n_trees / 3))``. After each of
    five single-point updates, both attributes must equal that reduced value.
    """
    expected_after_update = max(1, int(np.round(n_trees / 3)))
    nnd = NNDescent(np.array([[1.0]]), n_neighbors=1, n_trees=n_trees)

    # No update has happened yet, so these messages must not blame update().
    assert nnd.n_trees == n_trees, "NN-descent constructor changed the number of trees"
    assert (
        nnd.n_trees_after_update == expected_after_update
    ), "The value of the n_trees_after_update in NN-descent before any update is wrong"

    for i in range(5):
        nnd.update(xs_fresh=np.array([[i]], dtype=np.float64))
        assert (
            nnd.n_trees == expected_after_update
        ), "The value of the n_trees in NN-descent after update(s) is wrong"
        assert (
            nnd.n_trees_after_update == expected_after_update
        ), "The value of the n_trees_after_update in NN-descent after update(s) is wrong"
"accuracy on nearest neighbors") k = 10 xs_orig, xs_fresh, xs_updated, indices_updated = update_data[case] queries1 = xs_orig # original index = NNDescent(xs_orig, metric=metric, n_neighbors=40, random_state=1234) index.prepare() evaluate(index, xs_orig, queries1) # updated index.update(xs_fresh=xs_fresh, xs_updated=xs_updated, updated_indices=indices_updated) if xs_fresh is not None: xs = np.vstack((xs_orig, xs_fresh)) queries2 = np.vstack((queries1, xs_fresh)) else: xs = xs_orig queries2 = queries1 if indices_updated is not None: xs[indices_updated] = xs_updated evaluate(index, xs, queries2) if indices_updated is not None: evaluate(index, xs, xs_updated) @pytest.mark.parametrize("n_trees", [1, 2, 3, 10])