def _assert_points_find_themselves(forest, points):
    """Assert every point is retrievable from ``forest`` under its own index.

    For each ``(i, point)`` in ``enumerate(points)`` this checks that

    * the first result of ``forest.query(point, 10)`` is ``i``, and
    * ``i`` is a member of every leaf whose code appears in
      ``forest.encode(point)``.
    """
    # Materialise each leaf as a set so the membership checks below are cheap.
    members_by_code = {
        code: set(indices) for code, indices in forest.get_leaf_nodes()
    }

    for i, point in enumerate(points):
        nns = forest.query(point, 10)[:10]
        assert nns[0] == i

        for code in forest.encode(point):
            assert i in members_by_code[code]


def test_find_self():
    """Check that fitted forests of several sizes find each training point.

    The same checks are repeated on a pickled-and-restored copy of every
    forest, so the test also covers serialisation of the fitted index.
    """
    X_train, _ = _get_mnist_data()

    for no_trees in (1, 5, 10, 50):
        forest = RPForest(leaf_size=10, no_trees=no_trees)
        forest.fit(X_train)
        _assert_points_find_themselves(forest, X_train)

        # Round-trip through pickle; the bytes are produced right here, so
        # loading them is safe.
        forest = pickle.loads(pickle.dumps(forest))
        _assert_points_find_themselves(forest, X_train)
def test_clear():
    """Check that ``clear()`` empties every leaf of a fitted forest.

    After fitting, every leaf reported by ``get_leaf_nodes()`` must hold
    something truthy; after ``clear()`` every reported leaf must be falsy.
    """
    X_train, _ = _get_mnist_data()

    forest = RPForest(leaf_size=10, no_trees=10)
    forest.fit(X_train)

    # Before clearing: no leaf may be empty.
    assert all(members for _code, members in forest.get_leaf_nodes())

    forest.clear()

    # After clearing: no leaf may still hold indices.
    assert not any(members for _code, members in forest.get_leaf_nodes())