CST PA3
3-6 Nearest Neighbour
算法构思
本题需使用 KD-Tree
数据结构.
数据结构
// Line 16
struct Node {
public:
int d[5];
bool friend operator<(Node x, Node y) {
return x.d[comp] < y.d[comp];
}
} node[N], p;
使用全局变量 int dimension
记录操作向量的维数, Node
结构体支持存储一个 $k$ 维向量, 其中 $2\le k\le 5$.
// Line 26
struct Tree {
public:
int v[5][2];
Node p;
bool leftnode = false;
bool rightnode = false;
bool is_leaf = false;
} tree[4 * N];
每个 Tree
结构体实例对应 KD-Tree
中的一个节点, leftnode
与 rightnode
记录了其是否存在左、右子节点, is_leaf
记录其是否为叶节点.
每个 Tree
实例内部保存了一个向量 Node p
, 而 int v[5][2]
记录了其对应的 KD-Tree
节点集在 $k$ 维空间内覆盖的范围, 其中 v[i][0]
, v[i][1]
分别记录了 Tree
节点在空间中第 $i$ 维的范围.
Euclid 距离
// Line 36
ll Dist(Node x, Node y){
ll dis = 0;
for (int i = 0; i < dimension; i++)
dis += 1ll * (x.d[i] - y.d[i]) * (x.d[i] - y.d[i]);
return dis;
}
首先为两个向量定义其 $Euclid$ 距离.
// Line 44
ll Dist(Node x, Tree y){
ll dis = 0;
int dist[5];
for (int i = 0; i < dimension; i++) {
if (x.d[i] < y.v[i][0]) dist[i] = y.v[i][0] - x.d[i];
else if (x.d[i] > y.v[i][1]) dist[i] = x.d[i] - y.v[i][1];
else dist[i] = 0;
}
for (int i = 0; i < dimension; i++)
dis += 1ll * dist[i] * dist[i];
return dis;
}
再为一个点与一个区域定义其最短 $Euclid$ 距离. 逐次比较该点第 $i$ 维的坐标与区域在第 $i$ 维的坐标范围, 并取最短距离. 将每个维度的距离综合即得到了该点到此区域的最短 $Euclid$ 距离.
中位排序
我们定义一个中位排序操作, 它将序列的中位数放置到正确的位置, 同时满足前半段的元素均小于中位元素, 后半段的元素均大于中位元素.
该算法具体实现参考自此教程, 是一种期望为线性时间的选择算法, 时间复杂度在平均情况下为 $O(n)$.
// Line 58
void Swap(int l, int r) {
Node tmp = node[l];
node[l] = node[r];
node[r] = tmp;
}
int Partion(int l, int r) {
int k = l;
for (int i = l; i < r; ++i)
if (node[i] < node[r]) {
Swap(k, i);
k++;
}
Swap(k, r);
return k;
}
int FindMedium(int l, int r) {
int mid = (l + r) >> 1;
while (true) {
int k = Partion(l, r);
if (k == mid) return k;
else if (k > mid) r = k - 1;
else l = k + 1;
}
}
每次 Partion()
操作的执行都能确定一个元素的正确位置 (若该元素为第 $k$ 大, 那么其一定处于第 $k$ 位), 同时其左的元素均小于该元素, 其右的元素均大于该元素. 如果 Partion
返回的位置恰为序列的中位, 那么中位排序已完成; 否则递归对剩余元素进行中位排序即可.
建树
// Line 104
void Build(int k, int l, int r, int direction) {
int mid = (l + r) >> 1;
if(l == r) {
tree[k].p = node[mid];
for(int i = 0; i < dimension; i++) {
tree[k].v[i][0] = node[mid].d[i];
tree[k].v[i][1] = node[mid].d[i];
}
tree[k].is_leaf = true;
return;
}
comp = direction;
FindMedium(l, r);
tree[k].p = node[mid];
if (l <= mid - 1) {
tree[k].leftnode = true;
Build(left_son(k), l, mid - 1, (direction + 1) % dimension);
}
tree[k].rightnode = true;
Build(right_son(k), mid + 1, r, (direction + 1) % dimension);
PushUp(k);
}
其中全局变量 int comp
的取值范围为 $[0, dimension - 1]$, 它记录了当前建树时从哪个维度的坐标对当前向量进行排序.
一个 tree[k]
节点可能会对应多个 Node
向量, 我们调用 FindMedium(l, r)
对这些向量排序, 将按第 comp
维度的坐标排序后的中位向量放置在 Node node[N]
中位 node[mid]
上, 同时满足在其之前的向量的第 comp
维度坐标均小于 node[mid]
, 在其之后的向量的第 comp
维度坐标均大于 node[mid]
.
随后将 node[mid]
记录在 tree[k]
中, 并递归建立其左右子树, 在 tree[k]
内使用 bool leftnode
与 bool rightnode
记录, 最后调用 PushUp(k)
自底向顶更新每个 tree[k]
内的节点集所覆盖的区域的端点.
更新
对于叶节点, 我们在 Build()
中将其覆盖的区域端点设置为该点本身:
...
// Line 107
if(l == r) { // 叶节点
tree[k].p = node[mid];
for(int i = 0; i < dimension; i++) {
tree[k].v[i][0] = node[mid].d[i];
tree[k].v[i][1] = node[mid].d[i];
}
tree[k].is_leaf = true;
return;
}
...
对于含有子节点的父节点 tree[k]
, 根据我们建树的方式, 其右子节点 tree[right]
必然存在. 首先在右子节点 tree[right]
的覆盖区域内加入点 tree[k].p
; 若左子节点 tree[left]
存在, 再使用其覆盖区域更新 tree[k]
的覆盖区域.
// Line 89
void PushUp(int k) {
int left = left_son(k);
int right = right_son(k);
for(int i = 0; i < dimension; i++) {
tree[k].v[i][0] = min(tree[right].v[i][0], tree[k].p.d[i]);
tree[k].v[i][1] = max(tree[right].v[i][1], tree[k].p.d[i]);
}
if (tree[k].leftnode == true)
for(int i = 0; i < dimension; i++) {
tree[k].v[i][0] = min(tree[k].v[i][0], tree[left].v[i][0]);
tree[k].v[i][1] = max(tree[k].v[i][1], tree[left].v[i][1]);
}
}
查询
// Line 134
void Query(Node p, int k) {
ANS = min(ANS, Dist(p, tree[k].p));
if (tree[k].is_leaf) return;
if (tree[k].leftnode == true) {
ll left_min = Dist(p, tree[left_son(k)]);
ll right_min = Dist(p, tree[right_son(k)]);
if (left_min < right_min) ...
else ...
}
else
if (Dist(p, tree[right_son(k)]) < ANS)
Query(p, right_son(k));
}
给出点 Node p
与子树 tree[k]
, 我们给出 p
到 tree[k]
中的点的最短距离. 其中 ANS
是一个记录当前查询到的最短距离的全局变量, 在每次总查询前, 我们将其复位为 INF = 1e16
.
首先使用 tree[k]
内记录的点与 p
的距离更新最短距离 ANS
; 若左子节点不存在, 特判 p
到 tree[right]
区域的最小值, 当小于 ANS
时执行向下递归查询; 若左子节点存在, 那么同时计算 p
到 tree[left]
区域的最小值, 并首先递归查询距离较短的子树.
输入与查询
封装为 Init()
与 Work()
函数.
问题解决
- 起初出现了
Runtime Error(trap 14)
的报错, 在对Tree[]
进行扩容后消除, 这是因为对 $n$ 个点建立KD-Tree
需要不止 $n$ 个Tree
节点来存储. - 使用全局变量
comp
记录当前建树所选取坐标轴的维度, 并在向下递归过程中将其+ 1
. - 首次实现时, 我将所有向量都保存叶节点内, 这样做会导致第 $13$ 个点开始出现
Time Limit Exceeded
. 再次实现时, 我在每个内部节点也保存了一个向量值, 并在Query()
过程向左右子树递归时, 首先访问距离较短的那一个, 起到了剪枝优化的效果. - 对于
FindMedium()
的实现参考了 此教程.
复杂度估计
时间复杂度
Init()
过程读入输入数据的时间复杂度为 $O(nd)$, 建树过程中每一步时间消耗主要源自 FindMedium(l, r)
, 这是 $O(n)$ 复杂度的, 逐层建树的时间复杂度为 $O(nlogn)$.
每次进行 Query()
都进行了剪枝操作, 优化了访问不可能包含答案点的区域的情形. 考虑最坏情况下查询的时间复杂度, 与遍历一棵 $d$ 维 $KD-Tree$ 是等价的, 时间复杂度为 $O(n^{1 - 1/d})$. 进行 $m$ 次查询的时间复杂度为 $O(mn^{1 - 1/d})$.
综上, 算法总时间复杂度为 $O(nlogn + mn^{1 - 1/d})$.
空间复杂度
算法空间复杂度主要来自读取并存储数据的过程:
每个 Node
结构体实例占用的空间为: $5\times 4B = 20B$.
每个 Tree
结构体实例占用的空间为: $10\times 4B + 20B + 4B = 64B$.
实际占用的 Node
, Tree
结构体数目正比于输入数据的规模, 空间复杂度为 $O(nd)$.
最坏情况下程序所占用的内存约为: $20B\times 10^5 + 64B\times 4\times 10^5 = 26.3MB \ll 256MB$.
因此空间复杂度满足要求.