kd-tree implementation in Python

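1. Building a kd-tree requires finding a median at every split, so we first use a quicksort-style partition (quickselect) to get the median of a list.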

import matplotlib.pyplot as plt
import numpy as np

class QuickSort(object):
    "Quicksort-style partitioning used to find the median of an array"

    def __init__(self, low, high, array):
        self._array = array
        self._low = low
        self._high = high
        self._medium = (low+high+1)//2  # index of the (upper) median; // is floor division in Python 3

    def get_medium_num(self):
        return self.quick_sort_for_medium(self._low, self._high, 
                                          self._medium, self._array)

    def quick_sort_for_medium(self, low, high, medium, array):  # quickselect: partition until the pivot lands on the median index
        if high == low:
            return array[low]  # one element left, so it is the median
        if high > low:
            index, partition = self.sort_partition(low, high, array)
            # print(array[low:index], partition, array[index+1:high+1])
            if index == medium:
                return partition
            if index > medium:
                return self.quick_sort_for_medium(low, index-1, medium, array)
            else:
                return self.quick_sort_for_medium(index+1, high, medium, array)

    def quick_sort(self, low, high, array):  # ordinary in-place quicksort
        if high > low:
            index, partition = self.sort_partition(low, high, array)
            # print(array[low:index], partition, array[index+1:high+1])
            self.quick_sort(low, index-1, array)
            self.quick_sort(index+1, high, array)

    def sort_partition(self, low, high, array):  # use the first element as the pivot and split the range into two parts around it
        index_i = low
        index_j = high
        partition = array[low]
        while index_i < index_j:
            while (index_i < index_j) and (array[index_j] >= partition):
                index_j -= 1
            if index_i < index_j:
                array[index_i] = array[index_j]
                index_i += 1
            while (index_i < index_j) and (array[index_i] < partition):
                index_i += 1
            if index_i < index_j:
                array[index_j] = array[index_i]
                index_j -= 1
        array[index_i] = partition
        return index_i, partition
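
As a quick reference for which element the index `(low+high+1)//2` selects, here is a minimal sketch. The helper name `upper_median` is hypothetical and not part of the class above; it uses `sorted` instead of partitioning, so it is a check on the expected result rather than a replacement for quickselect. When the class is called with `low=0` and `high=len(array)-1`, it should return the same element.

```python
def upper_median(values):
    """Return the element at sorted index len(values) // 2.

    For low=0 and high=len(values)-1 this is the same index as
    (low + high + 1) // 2, i.e. the upper median for even lengths.
    """
    ordered = sorted(values)  # O(n log n); quickselect averages O(n)
    return ordered[len(ordered) // 2]


print(upper_median([7, 2, 9, 4, 5]))  # odd length: the middle element, 5
print(upper_median([7, 2, 9, 4]))     # even length: the larger middle element, 7
```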

2. Building the kd-tree
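
A minimal sketch of the construction step, assuming the points are equal-length tuples and that the split axis cycles with depth. `KDNode` and `build_kdtree` are hypothetical names used only for illustration, and `sorted` stands in for the quickselect class from step 1 to keep the example short.

```python
class KDNode:
    """One kd-tree node: a point, the axis it splits on, and two subtrees."""

    def __init__(self, point, axis, left=None, right=None):
        self.point = point
        self.axis = axis
        self.left = left
        self.right = right


def build_kdtree(points, depth=0):
    """Recursively split `points` at the median along a cycling axis."""
    if not points:
        return None
    axis = depth % len(points[0])            # assumption: axis cycles x, y, x, ...
    ordered = sorted(points, key=lambda p: p[axis])
    mid = len(ordered) // 2                  # upper median, as in step 1
    return KDNode(
        ordered[mid],
        axis,
        build_kdtree(ordered[:mid], depth + 1),      # points before the median
        build_kdtree(ordered[mid + 1:], depth + 1),  # points after the median
    )


points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
root = build_kdtree(points)
print(root.point, root.left.point, root.right.point)  # (7, 2) (5, 4) (9, 6)
```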

Test code:

3. Searching the kd-tree

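Continue by adding the following functions to the class. The basic idea is to push the nodes along the search path onto a stack, then pop them off one at a time.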
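
The push-then-pop idea can be sketched as follows. This is a standalone illustration under stated assumptions, not the class methods themselves: nodes are plain tuples `(point, axis, left, right)`, the names `build`, `nearest`, and `push_path` are hypothetical, distance is Euclidean, and `math.dist` requires Python 3.8 or later.

```python
import math


def build(points, depth=0):
    """Return a nested tuple (point, axis, left, right), or None if empty."""
    if not points:
        return None
    axis = depth % len(points[0])
    ordered = sorted(points, key=lambda p: p[axis])
    mid = len(ordered) // 2
    return (ordered[mid], axis,
            build(ordered[:mid], depth + 1),
            build(ordered[mid + 1:], depth + 1))


def nearest(root, target):
    """Return (point, distance) of the stored point closest to `target`."""
    best_point, best_dist = None, math.inf
    stack = []

    def push_path(node):
        # Walk down toward the target, stacking every node on the way.
        while node is not None:
            stack.append(node)
            point, axis, left, right = node
            node = left if target[axis] < point[axis] else right

    push_path(root)
    while stack:
        point, axis, left, right = stack.pop()
        dist = math.dist(point, target)
        if dist < best_dist:
            best_point, best_dist = point, dist
        # If the current best sphere crosses this node's splitting plane,
        # the subtree on the other side may hold a closer point.
        if abs(target[axis] - point[axis]) < best_dist:
            push_path(right if target[axis] < point[axis] else left)
    return best_point, best_dist


points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
tree = build(points)
print(nearest(tree, (3, 4.5))[0])  # (2, 3)
```

Because nodes on the near side are pushed after their parent, they are popped first, so each parent is examined with the best distance found beneath it already known.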

Test code:

4. Finding the k nearest nodes

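To find the k nearest nodes, keep an array of k elements and compare each candidate against those k elements in the function _check_nearest. At the place marked <*>, compare against the largest distance among the k elements instead of the single best distance. The rest of the code is omitted.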
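
A minimal sketch of that change, under the same assumptions as a plain stack-based search (tuple nodes `(point, axis, left, right)`, Euclidean distance, Python 3.8 or later for `math.dist`). The names `build` and `k_nearest` are hypothetical, and a `heapq` max-heap is swapped in for the k-element array so that the largest kept distance is always at `heap[0]`.

```python
import heapq
import math


def build(points, depth=0):
    """Return a nested tuple (point, axis, left, right), or None if empty."""
    if not points:
        return None
    axis = depth % len(points[0])
    ordered = sorted(points, key=lambda p: p[axis])
    mid = len(ordered) // 2
    return (ordered[mid], axis,
            build(ordered[:mid], depth + 1),
            build(ordered[mid + 1:], depth + 1))


def k_nearest(root, target, k):
    """Return up to k (distance, point) pairs, closest first."""
    if k < 1:
        return []
    heap = []   # entries are (-distance, point), so heap[0] is the farthest kept
    stack = []

    def push_path(node):
        # Walk down toward the target, stacking every node on the way.
        while node is not None:
            stack.append(node)
            point, axis, left, right = node
            node = left if target[axis] < point[axis] else right

    push_path(root)
    while stack:
        point, axis, left, right = stack.pop()
        dist = math.dist(point, target)
        if len(heap) < k:
            heapq.heappush(heap, (-dist, point))
        elif dist < -heap[0][0]:
            heapq.heapreplace(heap, (-dist, point))  # drop the current farthest
        # Prune against the largest kept distance; until k points are held,
        # nothing can be pruned.
        worst = -heap[0][0] if len(heap) == k else math.inf
        if abs(target[axis] - point[axis]) < worst:
            push_path(right if target[axis] < point[axis] else left)
    return sorted((-neg, point) for neg, point in heap)


points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
tree = build(points)
print([p for _, p in k_nearest(tree, (3, 4.5), 3)])  # [(2, 3), (5, 4), (4, 7)]
```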

Test code:

5. Other notes

The algorithm here does not handle the following cases:

  • Several data points lie on the same splitting hyperplane

  • Several data points are at the same distance from the target node
