Nearest Neighbour


CST PA3

3-6 Nearest Neighbour

算法构思

本题需使用 KD-Tree 数据结构.

数据结构

// Line 16
struct Node {
public:
    int d[5];
    bool friend operator<(Node x, Node y) {
        return x.d[comp] < y.d[comp];
    }
} node[N], p;

使用全局变量 int dimension 记录操作向量的维数, Node 结构体支持存储一个 $k$ 维向量, 其中 $2\le k\le 5$.

// Line 26
struct Tree {
public:
    int v[5][2];
    Node p;
    bool leftnode = false;
    bool rightnode = false;
    bool is_leaf = false;
} tree[4 * N];

每个 Tree 结构体实例对应 KD-Tree 中的一个节点, leftnoderightnode 记录了其是否存在左、右子节点, is_leaf 记录其是否为叶节点.

每个 Tree 实例内部保存了一个向量 Node p, 而 int v[5][2] 记录了其对应的 KD-Tree 节点集在 $k$ 维空间内覆盖的范围, 其中 v[i][0], v[i][1] 分别记录了 Tree 节点在空间中第 $i$ 维的范围.

Euclid 距离

// Line 36
ll Dist(Node x, Node y){
    ll dis = 0;
    for (int i = 0; i < dimension; i++)
        dis += 1ll * (x.d[i] - y.d[i]) * (x.d[i] - y.d[i]);
    return dis;
}

首先为两个向量定义其 $Euclid$ 距离.

// Line 44
ll Dist(Node x, Tree y){
    ll dis = 0;
    int dist[5];
    for (int i = 0; i < dimension; i++) {
        if (x.d[i] < y.v[i][0]) dist[i] = y.v[i][0] - x.d[i];
        else if (x.d[i] > y.v[i][1]) dist[i] = x.d[i] - y.v[i][1];
        else dist[i] = 0;
    }
    for (int i = 0; i < dimension; i++)
        dis += 1ll * dist[i] * dist[i];
    return dis;
}

再为一个点与一个区域定义其最短 $Euclid$ 距离. 逐次比较该点第 $i$ 维的坐标与区域在第 $i$ 维的坐标范围, 并取最短距离. 将每个维度的距离综合即得到了该点到此区域的最短 $Euclid$ 距离.

中位排序

我们定义一个中位排序操作, 它将序列的中位数放置到正确的位置, 同时满足前半段的元素均小于中位元素, 后半段的元素均大于中位元素.

该算法具体实现参考自此教程, 是一种期望为线性时间的选择算法, 时间复杂度在平均情况下为 $O(n)$.

// Line 58
void Swap(int l, int r) {
    Node tmp = node[l];
    node[l] = node[r];
    node[r] = tmp;
}

int Partion(int l, int r) {
    int k = l;
    for (int i = l; i < r; ++i)
        if (node[i] < node[r]) {
            Swap(k, i);
            k++;
        }
    Swap(k, r);
    return k;
}

int FindMedium(int l, int r) {
    int mid = (l + r) >> 1;
    while (true) {
        int k = Partion(l, r);
        if (k == mid) return k;
        else if (k > mid) r = k - 1;
        else l = k + 1;
    }
}

每次 Partion() 操作的执行都能确定一个元素的正确位置 (若该元素为第 $k$ 大, 那么其一定处于第 $k$ 位), 同时其左的元素均小于该元素, 其右的元素均大于该元素. 如果 Partion 返回的位置恰为序列的中位, 那么中位排序已完成; 否则递归对剩余元素进行中位排序即可.

建树

// Line 104
void Build(int k, int l, int r, int direction) {
    int mid = (l + r) >> 1;
    if(l == r) {
        tree[k].p = node[mid];
        for(int i = 0; i < dimension; i++) {
            tree[k].v[i][0] = node[mid].d[i];
            tree[k].v[i][1] = node[mid].d[i];
        }
        tree[k].is_leaf = true;
        return;
    }
    comp = direction;
    FindMedium(l, r);
    tree[k].p = node[mid];
    if (l <= mid - 1) {
        tree[k].leftnode = true;
        Build(left_son(k), l, mid - 1, (direction + 1) % dimension);
    }
    tree[k].rightnode = true;
    Build(right_son(k), mid + 1, r, (direction + 1) % dimension);
    PushUp(k);
}

其中全局变量 int comp 的取值范围为 $[0, dimension - 1]$, 它记录了当前建树时从哪个维度的坐标对当前向量进行排序.

一个 tree[k] 节点可能会对应多个 Node 向量, 我们调用 FindMedium(l, r) 对这些向量排序, 将按第 comp 维度的坐标排序后的中位向量放置在 Node node[N] 中位 node[mid] 上, 同时满足在其之前的向量的第 comp 维度坐标均小于 node[mid] , 在其之后的向量的第 comp 维度坐标均大于 node[mid].

随后将 node[mid] 记录在 tree[k] 中, 并递归建立其左右子树, 在 tree[k] 内使用 bool leftnodebool rightnode 记录, 最后调用 PushUp(k) 自底向顶更新每个 tree[k] 内的节点集所覆盖的区域的端点.

更新

对于叶节点, 我们在 Build() 中将其覆盖的区域端点设置为该点本身:

...
// Line 107
if(l == r) { // 叶节点
    tree[k].p = node[mid];
    for(int i = 0; i < dimension; i++) {
        tree[k].v[i][0] = node[mid].d[i];
        tree[k].v[i][1] = node[mid].d[i];
    }
    tree[k].is_leaf = true;
    return;
}
...

对于含有子节点的父节点 tree[k], 根据我们建树的方式, 其右子节点 tree[right] 必然存在. 首先在右子节点 tree[right] 的覆盖区域内加入点 tree[k].p; 若左子节点 tree[left] 存在, 再使用其覆盖区域更新 tree[k] 的覆盖区域.

// Line 89
void PushUp(int k) {
    int left = left_son(k);
    int right = right_son(k);
    for(int i = 0; i < dimension; i++) {
        tree[k].v[i][0] = min(tree[right].v[i][0], tree[k].p.d[i]);
        tree[k].v[i][1] = max(tree[right].v[i][1], tree[k].p.d[i]);
    }
    if (tree[k].leftnode == true)
        for(int i = 0; i < dimension; i++) {
            tree[k].v[i][0] = min(tree[k].v[i][0], tree[left].v[i][0]);
            tree[k].v[i][1] = max(tree[k].v[i][1], tree[left].v[i][1]);
        }
}

查询

// Line 134
void Query(Node p, int k) {
    ANS = min(ANS, Dist(p, tree[k].p));
    if (tree[k].is_leaf) return;
    if (tree[k].leftnode == true) {
        ll left_min = Dist(p, tree[left_son(k)]);
        ll right_min = Dist(p, tree[right_son(k)]);
        if (left_min < right_min) ...
        else ...
    }
    else
        if (Dist(p, tree[right_son(k)]) < ANS)
            Query(p, right_son(k));
}

给出点 Node p 与子树 tree[k], 我们给出 ptree[k] 中的点的最短距离. 其中 ANS 是一个记录当前查询到的最短距离的全局变量, 在每次总查询前, 我们将其复位为 INF = 1e16.

首先使用 tree[k] 内记录的点与 p 的距离更新最短距离 ANS; 若左子节点不存在, 特判 ptree[right] 区域的最小值, 当小于 ANS 时执行向下递归查询; 若左子节点存在, 那么同时计算 ptree[left] 区域的最小值, 并首先递归查询距离较短的子树.

输入与查询

封装为 Init()Work() 函数.

问题解决

  • 起初出现了 Runtime Error(trap 14) 的报错, 在对 Tree[] 进行扩容后消除, 这是因为对 $n$ 个点建立 KD-Tree 需要不止 $n$ 个 Tree 节点来存储.
  • 使用全局变量 comp 记录当前建树所选取坐标轴的维度, 并在向下递归过程中将其 + 1.
  • 首次实现时, 我将所有向量都保存叶节点内, 这样做会导致第 $13$ 个点开始出现 Time Limit Exceeded. 再次实现时, 我在每个内部节点也保存了一个向量值, 并在 Query() 过程向左右子树递归时, 首先访问距离较短的那一个, 起到了剪枝优化的效果.
  • 对于 FindMedium() 的实现参考了 此教程.

复杂度估计

时间复杂度

Init() 过程读入输入数据的时间复杂度为 $O(nd)$, 建树过程中每一步时间消耗主要源自 FindMedium(l, r), 这是 $O(n)$ 复杂度的, 逐层建树的时间复杂度为 $O(nlogn)$.

每次进行 Query() 都进行了剪枝操作, 优化了访问不可能包含答案点的区域的情形. 考虑最坏情况下查询的时间复杂度, 与遍历一棵 $d$ 维 $KD-Tree$ 是等价的, 时间复杂度为 $O(n^{1 - 1/d})$. 进行 $m$ 次查询的时间复杂度为 $O(mn^{1 - 1/d})$.

综上, 算法总时间复杂度为 $O(nlogn + mn^{1 - 1/d})$.

空间复杂度

算法空间复杂度主要来自读取并存储数据的过程:

每个 Node 结构体实例占用的空间为: $5\times 4B = 20B$.

每个 Tree 结构体实例占用的空间为: $10\times 4B + 20B + 4B = 64B$.

实际占用的 Node, Tree 结构体数目正比于输入数据的规模, 空间复杂度为 $O(nd)$.

最坏情况下程序所占用的内存约为: $20B\times 10^5 + 64B\times 4\times 10^5 = 26.3MB \ll 256MB$.

因此空间复杂度满足要求.


文章作者: Chengsx
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Chengsx !
  目录