kd树python实现
1. 首先在构造kd树的时需要寻找中位数,因此用快速排序来获取一个list中的中位数
import matplotlib.pyplot as plt
import numpy as np
class QuickSort(object):
"Quick Sort to get medium number"
def __init__(self, low, high, array):
self._array = array
self._low = low
self._high = high
self._medium = (low+high+1)//2 # python3中的整除
def get_medium_num(self):
return self.quick_sort_for_medium(self._low, self._high,
self._medium, self._array)
def quick_sort_for_medium(self, low, high, medium, array): #用快速排序来获取中位数
if high == low:
return array[low] # find medium
if high > low:
index, partition = self.sort_partition(low, high, array);
#print array[low:index], partition, array[index+1:high+1]
if index == medium:
return partition
if index > medium:
return self.quick_sort_for_medium(low, index-1, medium, array)
else:
return self.quick_sort_for_medium(index+1, high, medium, array)
def quick_sort(self, low, high, array): #正常的快排
if high > low:
index, partition = self.sort_partition(low, high, array);
#print array[low:index], partition, array[index+1:high+1]
self.quick_sort(low, index-1, array)
self.quick_sort(index+1, high, array)
def sort_partition(self, low, high, array): # 用第一个数将数组里面的数分成两部分
index_i = low
index_j = high
partition = array[low]
while index_i < index_j:
while (index_i < index_j) and (array[index_j] >= partition):
index_j -= 1
if index_i < index_j:
array[index_i] = array[index_j]
index_i += 1
while (index_i < index_j) and (array[index_i] < partition):
index_i += 1
if index_i < index_j:
array[index_j] = array[index_i]
index_j -= 1
array[index_i] = partition
return index_i, partition2. 构造kd树
测试代码:
3. 搜索kd树
在类中继续添加如下函数,基本的思想是将路径上的节点依次入栈,再逐个出栈。
测试代码:
4. 寻找k个最近节点
如果要寻找k个最近节点,则需要保存k个元素的数组,并在函数_check_nearest中与k个元素做比较,然后在标记<*>的地方跟k个元素的最大值比较。其他代码略。
测试代码:
5.其他
这里的算法没有考虑到下面的情况:
多个数据点在同一个超平面上
有多个数据点跟目标节点的距离相同
Last updated