示例#1
0
def main(test_times=100):
    """Cross-check KD-tree nearest-neighbour search against a linear scan.

    Each trial generates random data with ``gen_data``, builds a ``KDTree``
    over it, and looks up one random query point both through
    ``tree.nearest_neighbour_search`` and through ``exhausted_search``. The
    Euclidean distances from the query to the two returned points must agree.
    The accumulated wall-clock time of each search method is printed at the
    end.

    Args:
        test_times: Number of random trials to run (defaults to 100, the value
            that was previously hard-coded).

    Raises:
        AssertionError: If, in any trial, the KD-tree distance and the
            linear-scan distance differ beyond floating-point tolerance.
    """
    import math

    print("Testing KD Tree...")

    # The random-data parameters are identical for every trial, so set them once.
    low = 0
    high = 100
    n_rows = 1000
    n_cols = 2
    run_time_1 = run_time_2 = 0

    for _ in range(test_times):
        # Generate random data. gen_data is defined elsewhere; presumably it
        # returns an n_rows x n_cols matrix when n_cols is given and a flat
        # sequence otherwise (TODO confirm).
        X = gen_data(low, high, n_rows, n_cols)
        y = gen_data(low, high, n_rows)
        Xi = gen_data(low, high, n_cols)

        # Build the KD tree.
        tree = KDTree()
        tree.build_tree(X, y)

        # KD-tree search (timed).
        start = time()
        nd = tree.nearest_neighbour_search(Xi)
        run_time_1 += time() - start
        ret1 = get_eu_dist(Xi, nd.split[0])

        # Plain linear search (timed), used as the reference answer.
        start = time()
        row = exhausted_search(X, Xi)
        run_time_2 += time() - start
        ret2 = get_eu_dist(Xi, row)

        # Compare the results. Distinct but equidistant points may yield
        # distances differing in the last float bit, so compare with a
        # tolerance. Raise explicitly rather than use `assert`, because
        # `python -O` strips asserts and the script would then report success
        # without checking anything.
        if not math.isclose(ret1, ret2):
            raise AssertionError(
                "target:%s\nrestult1:%s\nrestult2:%s\ntree:\n%s" % (
                    Xi, nd, row, tree))

    print("%d tests passed!" % test_times)
    print("KD Tree Search %.2f s" % run_time_1)
    print("Exhausted search %.2f s" % run_time_2)
示例#2
0
    #      [3,4,6,1],
    #      [3,6,6,1],
    #      [3,5,9,2],
    #      [3,5,12,2],
    #      [3,5,13,2],]

    # Toy training set: seven 3-column points. The last column is 1 for every
    # point; presumably it is a class label (confirm against
    # KDTree.build_tree).
    X = np.array([[2, 3, 1], [5, 4, 1], [9, 6, 1], [8.5, 6, 1], [4, 7, 1],
                  [8, 1, 1], [7, 2, 1]])

    kd = KDTree()
    kd.build_tree(X)

    # All query rows are passed to search_neighbour in one call. The result
    # is then indexed once per query row, so it is expected to hold one entry
    # per row of test_x.
    test_x = np.array([[7, 6, 1], [3, 4.5, 1]])
    nearest = kd.search_neighbour(test_x)
    for i, point in enumerate(test_x):
        print(point, '--->', nearest[i])
示例#3
0
from data import Data
from kd_tree import KDTree

# Load data via Data, build a KDTree from part of it, and run one nearest query.
# Written to run on both Python 2 and Python 3:
# - print(...) with a single argument behaves the same on both.
# - dict.keys() returns a non-sliceable view on Python 3, so it is
#   materialised into a list before slicing.
kd = KDTree(2)  # presumably 2 is the dimensionality (query below is a 2-tuple)
d = Data()
d.extract('example.dat')
print(d.data)
# Pass the first seven keys of d.data to build_tree.
# NOTE(review): assuming d.data is a dict, key order is arbitrary on
# Python 2 / < 3.7 and insertion order afterwards. Which seven keys are picked
# can therefore differ between interpreters.
kd.build_tree(list(d.data.keys())[:7])
print(kd)
print(kd.nearest((7, 7), kd.root))