示例#1
0
class HotelTree:
    """Nearest-hotel lookup over (lon, lat, price) backed by a ``KDTree``.

    The tree is built from the ``tree_keys`` columns of the input frame, and
    query results are reported using the ``return_keys`` columns of the same
    rows.
    """

    # Columns used as coordinates when building and querying the tree.
    tree_keys = ["lon", "lat", "price"]
    # Columns copied into each result dictionary.
    return_keys = ["lon", "lat", "name", "stars", "price"]

    def __init__(self, base_df: pd.DataFrame):
        """Build the tree from ``base_df``.

        Args:
            base_df: frame that provides every column named in ``tree_keys``
                and ``return_keys``.
        """
        self.coords = np.array(base_df[self.tree_keys])
        self.arr = np.array(base_df[self.return_keys])
        self.tree = KDTree(self.coords, leaf_size=10)

    def query_w_filter(self, input_list):
        """Look up one hotel per query, constrained to a price range.

        Each query's price coordinate is the midpoint of its
        ``[min_price, max_price]`` range, and the filter passed for that
        coordinate is half the width of the range. The first two filter
        entries are ``-1``.
        NOTE(review): presumably a negative filter entry marks a dimension
        that contributes to the distance rather than acting as a filter, so
        lon/lat drive the nearest-neighbour search -- confirm against
        ``KDTree.query_filtered``.

        The query mappings are only read; no ``"price"`` key is written
        back into them.

        Args:
            input_list: iterable of mappings providing ``"min_price"``,
                ``"max_price"`` and every non-price key in ``tree_keys``.

        Returns:
            A list with one dictionary per query, in input order: either the
            ``return_keys`` values of the matched row, or
            ``{"missing": True}`` when the reported distance is not below
            infinity. An empty input yields an empty list without querying
            the tree.
        """
        input_points = []
        input_filters = []
        for indic in input_list:
            mid_price = (indic["min_price"] + indic["max_price"]) / 2
            input_filters.append([-1, -1, indic["max_price"] - mid_price])
            # Substitute the computed midpoint for "price" instead of storing
            # it in the caller's mapping.
            input_points.append([
                mid_price if k == "price" else indic[k]
                for k in self.tree_keys
            ])

        if not input_points:
            # Nothing to look up; avoid handing a shape-(0,) array to the tree.
            return []

        dist, _ind = self.tree.query_filtered(np.array(input_points),
                                              np.array(input_filters))

        # Only the first neighbour returned for each query is used.
        ind = [i[0] for i in _ind]

        results = []
        for row, (d, ) in zip(self.arr[ind, :], dist):
            if d < np.inf:
                results.append(dict(zip(self.return_keys, row)))
            else:
                results.append({"missing": True})
        return results
示例#2
0
def test_kd_tree_filter_query():
    """Exercise ``KDTree.query_filtered`` on a fixed case and on random data.

    The random part compares the tree against a brute-force reference in
    which a non-negative radius restricts candidates to points within that
    radius along the dimension, and a negative radius marks a dimension that
    contributes to the Euclidean distance instead.
    """
    # Small hand-checked case: filter on the first dimension only.
    points = np.array([
        [0, 2, 1],
        [2, 2, 3],
        [10, 3, 1],
        [4, 5, 6],
        [0, 0, 0],
        [-10, -3, 2],
        [1, -5, 2],
    ])
    small_tree = KDTree(points, leaf_size=1)
    single_query = np.array([[1, 1, 1]])

    dist, ind = small_tree.query_filtered(single_query, [[1, -1, -1]], k=2)

    assert ind[0][0] == 0
    assert ind[0][1] == 4
    assert dist[0][0] == 1

    # Randomised comparison against brute force (fixed seed).
    rand = np.random.RandomState(18)

    n_features = 6
    data_size = 2000
    query_num = 20

    X = rand.rand(data_size, n_features)

    tree2 = KDTree(X, leaf_size=4)

    queries = rand.rand(query_num, n_features)
    filter_radiuses = rand.rand(query_num, n_features) - 0.5

    expected_ind = []
    expected_dist = []
    for query, radii in zip(queries, filter_radiuses):
        # Negative radii impose no bound on their dimension.
        bounds = np.where(radii >= 0, radii, np.inf)
        dist_dims = radii < 0
        best_sq = np.inf
        best_row = 0
        for row, point in enumerate(X):
            if not (np.abs(point - query) <= bounds).all():
                continue
            sq = ((point[dist_dims] - query[dist_dims])**2).sum()
            if sq < best_sq:
                best_sq = sq
                best_row = row
        expected_ind.append([best_row])
        expected_dist.append([np.sqrt(best_sq)])

    bf_dist = np.array(expected_dist)
    bf_ind = np.array(expected_ind)

    dist, ind = tree2.query_filtered(queries, filter_radiuses, k=1)

    for radii, got_i, want_i, got_d, want_d in zip(
            filter_radiuses,
            ind[:, 0],
            bf_ind[:, 0],
            dist[:, 0],
            bf_dist[:, 0],
    ):
        # The index is only compared when at least one dimension
        # contributes to the distance.
        if (radii < 0).any():
            assert got_i == want_i
        assert got_d == want_d