Example #1
0
    def maintainHeap(self, Node, findPoint, k):
        """Offer ``Node`` as a candidate for the k nearest neighbours of ``findPoint``.

        ``self.hp_knn`` stores the current candidates as plain coordinate
        lists and is kept as a ``heapq`` heap of those lists.  That heap is
        ordered by coordinates, not by distance, so its first element is not
        the farthest candidate; the candidate to evict is therefore located
        with an explicit scan over the stored points.

        Args:
            Node: tree node whose ``point`` is the candidate.
            findPoint: the query point.
            k: maximum number of neighbours to keep.  Nothing is stored when
                ``k`` is not positive.
        """
        if k <= 0:
            return

        candidate = list(Node.point)
        neighbours = self.hp_knn

        # Still room: every candidate is accepted.
        if len(neighbours) < k:
            hq.heappush(neighbours, candidate)
            return

        # Locate the stored neighbour that lies farthest from the query.
        distances = [get_distance(findPoint, stored) for stored in neighbours]
        worst = max(range(len(neighbours)), key=distances.__getitem__)

        if get_distance(Node.point, findPoint) < distances[worst]:
            neighbours[worst] = candidate
            # The in-place replacement may break the heap invariant.
            hq.heapify(neighbours)
Example #2
0
    def getLeaf(self, findPoint):
        """Descend from the root to a leaf on the side of each split nearest ``findPoint``.

        At every inner node the child on the query's side of the split is
        preferred.  When the query lies exactly on the split, the child whose
        point is closer to the query is preferred.  If the preferred child
        does not exist, the other child is taken, so the walk always ends at
        a node without children.

        Args:
            findPoint: the query point, indexable by each node's ``split``.

        Returns:
            The leaf node reached, or ``None`` when the tree has no root.
        """
        node = self.root
        if not node:
            return None

        while node.left or node.right:
            axis = node.split
            query_value = findPoint[axis]
            pivot_value = node.point[axis]

            if query_value < pivot_value:
                preferred, fallback = node.left, node.right
            elif query_value > pivot_value:
                preferred, fallback = node.right, node.left
            elif node.left and node.right:
                # Query sits on the split: follow the closer child.
                if get_distance(findPoint, node.left.point) < get_distance(
                        findPoint, node.right.point):
                    preferred, fallback = node.left, node.right
                else:
                    preferred, fallback = node.right, node.left
            else:
                # On the split with a single child: that child is the only way down.
                preferred, fallback = node.left, node.right

            # The loop condition guarantees at least one child exists, so
            # this always advances and the walk cannot stall.
            node = preferred or fallback

        return node
Example #3
0
    def findNearest(self, findPoint):
        """Return the stored point closest to ``findPoint`` and its distance.

        The tree is searched depth-first with an explicit stack.  At each
        node the subtree on the query's side of the split is explored first;
        the opposite subtree is explored only if the splitting plane is
        closer to the query than the best distance found so far.

        NOTE(review): the pruning assumes ``get_distance`` is never smaller
        than the absolute difference along a single axis (true for Euclidean
        distance) -- confirm against ``get_distance``.

        Args:
            findPoint: the query point, indexable by each node's ``split``.

        Returns:
            A ``(point, distance)`` tuple for the nearest stored point, or
            ``(None, float("inf"))`` when the tree has no root.
        """
        if not self.root:
            return None, float("inf")

        # NP holds the closest point seen so far, min_dist its distance.
        NP = self.root.point
        min_dist = float("inf")

        # Each entry is (subtree root, lower bound on the distance from the
        # query to any point stored in that subtree).
        pending = [(self.root, 0)]
        while pending:
            node, bound = pending.pop()
            # Skip missing children and subtrees that cannot beat the best.
            if not node or bound >= min_dist:
                continue

            dd = get_distance(node.point, findPoint)
            if dd < min_dist:
                NP = node.point
                min_dist = dd

            # Split axis of the current node.
            split = node.split
            offset = findPoint[split] - node.point[split]
            if offset <= 0:
                near, far = node.left, node.right
            else:
                near, far = node.right, node.left

            # Push the far side first so the near side is popped, and thus
            # explored, first; the far side is re-checked against the
            # (possibly improved) best distance when it is popped.
            pending.append((far, max(bound, abs(offset))))
            pending.append((near, bound))

        return NP, min_dist
Example #4
0
def classify(x, test_data, labels, k):
    """Predict the label of ``x`` by majority vote among its k nearest samples.

    Args:
        x: the sample to classify.
        test_data: reference samples passed to ``get_distance(x, test_data)``;
            the result must provide ``argsort()`` (presumably a NumPy array
            with one distance per sample -- TODO confirm).
        labels: labels of the reference samples, indexed like the distances.
        k: number of nearest samples that vote; capped at the number of
            available samples.

    Returns:
        The label that received the most votes.

    Raises:
        IndexError: if no sample votes (``k <= 0`` or no reference samples).
    """
    dis = get_distance(x, test_data)
    dis_indices = dis.argsort()  # sample indices ordered by distance

    # Tally the labels of the k nearest samples (never more than exist).
    labelRecord = {}
    for i in range(min(k, len(dis_indices))):
        test_label = labels[dis_indices[i]]
        labelRecord[test_label] = labelRecord.get(test_label, 0) + 1

    # items() is available on both Python 2 and Python 3, unlike iteritems().
    sortedRecord = sorted(labelRecord.items(),
                          key=operator.itemgetter(1),
                          reverse=True)

    return sortedRecord[0][0]
Example #5
0
    def findKNN(self, findPoint, k):
        """Collect nearest-neighbour candidates for ``findPoint`` via ``maintainHeap``.

        Starts at the leaf returned by ``getLeaf`` and follows the ``parent``
        links up to the root, offering every node on that path to
        ``maintainHeap``.  At each step the sibling of the current node is
        offered as well when the current node is farther from the query than
        the parent's splitting plane.

        NOTE(review): only the sibling node itself is offered, not the nodes
        below it, so the collected candidates are not guaranteed to be the
        exact k nearest neighbours.

        Args:
            findPoint: the query point.
            k: number of neighbours requested, forwarded to ``maintainHeap``.

        Returns:
            None.  Candidates are accumulated by ``maintainHeap``.
        """
        nearestNode = self.getLeaf(findPoint)

        while nearestNode:
            cur_dist = get_distance(nearestNode.point, findPoint)
            self.maintainHeap(nearestNode, findPoint, k)

            parent = nearestNode.parent
            if parent:
                axis = parent.split
                plane_dist = abs(findPoint[axis] - parent.point[axis])
                if cur_dist > plane_dist:
                    # Identify the sibling through the tree links rather than
                    # by comparing coordinates: a node that ties with its
                    # parent on the split axis could otherwise select itself.
                    if nearestNode is parent.left:
                        brotherNode = parent.right
                    else:
                        brotherNode = parent.left
                    # The sibling exists.
                    if brotherNode:
                        self.maintainHeap(brotherNode, findPoint, k)

            nearestNode = parent
        return