KNN

Normal KNN

KNN (k-Nearest Neighbors) is an extremely simple classification algorithm.

The normal KNN algorithm has no explicit training process. For each test sample, the algorithm computes its distance to every training sample (typically the Euclidean, i.e. \(L_2\), distance), selects the \(k\) closest training samples, and assigns the test sample the majority class among those \(k\) neighbors.
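The procedure above fits in a few lines. Here is a minimal brute-force sketch (the function and variable names are my own, not from any library):

```python
import numpy as np
from collections import Counter

def knn_classify(train_X, train_y, x, k=3):
    """Brute-force KNN: squared L2 distance to every training sample, majority vote."""
    dists = np.sum(np.square(train_X - x), axis=1)   # distance to each training sample
    nearest = np.argsort(dists)[:k]                  # indices of the k closest samples
    return Counter(train_y[nearest]).most_common(1)[0][0]

train_X = np.array([[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [5.1, 4.9]])
train_y = np.array([0, 0, 1, 1])
print(knn_classify(train_X, train_y, np.array([0.2, 0.1])))  # → 0
```

Note that the squared distance is used: it is monotone in the true Euclidean distance, so the ranking of neighbors is unchanged and the square root can be skipped.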

With a smaller \(k\), the model becomes more complex and overfits easily; increasing \(k\) makes the model simpler.
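The effect of \(k\) shows up on a toy set with one noisy sample (a hypothetical example; the data here are made up for illustration):

```python
import numpy as np
from collections import Counter

def knn_predict(X, y, x, k):
    # Majority vote among the k nearest training samples (squared L2 distance).
    nearest = np.argsort(np.sum(np.square(X - x), axis=1))[:k]
    return Counter(y[nearest]).most_common(1)[0][0]

# One mislabeled outlier of class 1 sits right next to the query point.
X = np.array([[0.0, 0.0], [1.0, 1.0], [1.1, 0.9], [0.9, 1.1], [1.0, 0.8]])
y = np.array([1, 0, 0, 0, 0])
x = np.array([0.2, 0.2])
print(knn_predict(X, y, x, k=1))  # → 1: the single noisy neighbor decides
print(knn_predict(X, y, x, k=5))  # → 0: a larger k smooths the noise out
```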

kd Tree

A simple way to select the \(k\) nearest samples is a linear scan over the training set. However, this becomes extremely expensive computationally when the feature dimension and the dataset are large. To reduce the computation, we can store the training data in a kd-tree.

A kd-tree is a kind of binary tree. When splitting the samples along a coordinate, the algorithm uses the median value as the splitting point. Here is the code for constructing the kd-tree:

def get_split_node(self, samples, index, cur_node):
    # Sort along the current splitting dimension and take the (lower) median.
    samples = sorted(samples, key=lambda x: x[index])
    num = len(samples)

    node = TreeNode()
    node.parent_node = cur_node
    node.data = samples[(num - 1) // 2]  # lower median when num is even

    return node

def construct_kd_tree(self, training_samples, cur_node=None, depth=0):
    # Cycle through the dimensions as the tree grows deeper.
    index = depth % len(training_samples[0])

    # Compute the splitting node (median along the current dimension).
    split_node = self.get_split_node(training_samples, index, cur_node)

    # Partition the remaining samples into left and right subtrees.
    # (Samples tied with the median's coordinate are dropped here;
    # tie handling is omitted for simplicity.)
    left_subtree = [x for x in training_samples if x[index] < split_node.data[index]]
    right_subtree = [x for x in training_samples if x[index] > split_node.data[index]]

    # Recursively split the left and right subtrees.
    cur_node = split_node
    cur_node.depth = depth
    if len(left_subtree) == 1:
        node = TreeNode()
        node.data = left_subtree[0]
        node.parent_node = cur_node
        node.is_leaf = True
        node.depth = depth + 1
        cur_node.left_subnode = node
    elif len(left_subtree) > 1:
        cur_node.left_subnode = self.construct_kd_tree(left_subtree, cur_node, depth+1)

    if len(right_subtree) == 1:
        node = TreeNode()
        node.data = right_subtree[0]
        node.parent_node = cur_node
        node.is_leaf = True
        node.depth = depth + 1
        cur_node.right_subnode = node
    elif len(right_subtree) > 1:
        cur_node.right_subnode = self.construct_kd_tree(right_subtree, cur_node, depth+1)

    return cur_node
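The node container used by the construction code is not defined in this excerpt. Inferring its fields from how it is used, a minimal sketch (named `TreeNode` here) might look like:

```python
class TreeNode:
    """Minimal kd-tree node; field names inferred from the construction and search code."""
    def __init__(self):
        self.data = None             # the training sample stored at this node
        self.parent_node = None      # link back up the tree, used when backtracking
        self.left_subnode = None
        self.right_subnode = None
        self.depth = 0               # depth decides which dimension this node splits on
        self.is_leaf = False
        self.is_calculated = False   # visited flag consumed by the search phase
```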

Searching the kd-tree is a little more complex:

def search_nodes(self, root, test_data, k=3):
    def calculate_distance(p1, p2):
        # Squared Euclidean distance (monotone in the true distance).
        return np.sum(np.square(p1 - p2))

    def get_max_distance(k_list, test_data):
        # Distance and position of the current worst candidate.
        dists = [calculate_distance(p, test_data) for p in k_list]
        return np.max(dists), np.argmax(dists)

    def update_list(k_list, new_point, test_data):
        # Replace the worst candidate if the new point is closer.
        max_dis, max_index = get_max_distance(k_list, test_data)
        if calculate_distance(new_point, test_data) < max_dis:
            del k_list[max_index]
            k_list.append(new_point)

    def find_leaf(subroot, k_list):
        # Descend to the leaf region that the test point falls into.
        node = subroot
        while True:
            if node.is_leaf:
                if len(k_list) < k:
                    k_list.append(node.data)
                else:
                    update_list(k_list, node.data, test_data)
                node.is_calculated = True
                break
            index = node.depth % len(node.data)
            go_left = test_data[index] < node.data[index]
            next_node = node.left_subnode if go_left else node.right_subnode
            if next_node is None:
                # Only one child exists; follow it instead.
                next_node = node.right_subnode if go_left else node.left_subnode
            node = next_node

        return node

    k_list = []
    node = find_leaf(root, k_list)

    while node.parent_node is not None:
        node = node.parent_node
        if node.is_calculated:
            continue

        if len(k_list) < k:
            k_list.append(node.data)
        else:
            update_list(k_list, node.data, test_data)
        node.is_calculated = True

        # Find the unvisited branch under the current node.
        # The search backtracks from a child, so at most one branch is unvisited.
        branch = None
        if node.left_subnode is not None and not node.left_subnode.is_calculated:
            branch = node.left_subnode
        elif node.right_subnode is not None and not node.right_subnode.is_calculated:
            branch = node.right_subnode
        # If there is no unvisited branch, or the test point is no closer to the
        # splitting hyperplane than the worst candidate, skip this subtree.
        if branch is None:
            continue
        # Squared distance from the test point to this node's splitting hyperplane.
        index = node.depth % len(node.data)
        point2line = np.square(test_data[index] - node.data[index])
        max_dis, _ = get_max_distance(k_list, test_data)
        if point2line >= max_dis:
            continue
        # Descend into this branch to its leaf, updating the candidate list.
        node = find_leaf(branch, k_list)

    return k_list
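A useful way to sanity-check the kd-tree search is to compare its candidate list against a brute-force answer. A small reference helper (a sketch of mine, not part of the post's code) using `np.argpartition` to pick the \(k\) smallest distances without a full sort:

```python
import numpy as np

def brute_force_knn(points, x, k=3):
    """Reference answer for checking a kd-tree search: the k nearest points to x."""
    d = np.sum(np.square(points - x), axis=1)     # squared L2 distance to every point
    idx = np.argpartition(d, k - 1)[:k]           # k smallest, in arbitrary order
    return points[idx[np.argsort(d[idx])]]        # sorted from nearest to farthest

pts = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [10.0, 10.0]])
print(brute_force_knn(pts, np.array([0.0, 0.0]), k=3))
```

Sorting both this result and the candidate list produced by `search_nodes` should give the same set of points on any dataset without distance ties.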

The full code is available on GitHub (kd-tree codes).
