神奇的树
https://www.matiji.net/exam/dohomework/8235/2
题干
小码哥在花园中种了 颗神奇的树,花园中的第 颗树 正好是无根树 以 节点为根所形成的有根树。( 由输入给定)
小码哥想知道对于所有从 到 的 , 与花园中的多少颗树同构,即集合 的大小?( 与 同构也可统计进答案)
两棵有根树 同构当且仅当它们的大小相等,且存在一个顶点排列 使得在 中 是 的祖先当且仅当在 中 是 的祖先。
思路
注意时间限制 不是 , ,大概是 以下。
做法是树 hash,即对于一个树,将其编码成 hash。然后计算过程中是可以直接换根的,因此复杂度只有 。
树哈希的定义与计算方法
树哈希是一种将树结构编码成一个唯一哈希值的方法,它能够帮助我们快速判断两棵树是否同构。在同构的树中,节点的相对结构是相同的,但节点的编号可能不同。
树哈希的核心思想
我们通过递归地计算每个节点的哈希值 ,并结合子节点的哈希值来唯一确定一棵树的结构。
树哈希的定义包含以下几点:
- 叶子节点的哈希值是一个固定值,通常初始化为 (可以理解为基本单元)。
- 非叶子节点的哈希值由其所有子节点的哈希值通过某个函数组合得到。
- 为了避免哈希冲突,通常使用一个扰动函数 对子节点的哈希值进行处理,使哈希值更加均匀分布。
具体计算方法
我们定义一个节点 的哈希值 如下:
-
基础值 :每个节点自身初始哈希值设为 1。
-
扰动函数 :用于对子节点的哈希值进行扰动,常用形式是:
其中 (\oplus) 表示按位异或,或者可以用 等其他操作。
这个函数的作用是让相同结构但不同子树排列的节点产生不同的中间哈希值,减少冲突。 -
节点哈希值的计算公式:
对于一个节点 ,它的哈希值 等于自身基础值 加上所有子节点的扰动后哈希值之和:其中 是子节点 的哈希值经过扰动函数处理后的结果。
示例
考虑以下树结构:
自底向上计算哈希值
叶子节点(4 和 5):
叶子节点没有子节点,它们的哈希值初始化为 , , 。
节点 2(子节点为 4 和 5):
由自身的初始值和两个子节点的哈希值决定:
假设扰动函数 ,则:
所以:
节点 3(叶子节点):
。
根节点 1(子节点为 2 和 3):
由自身的初始值和两个子节点的哈希值决定:
已知 , ,并计算:
所以:
最终的哈希值:
节点 | 子节点 | |
---|---|---|
4 | 无 | 1 |
5 | 无 | 1 |
2 | 4, 5 | 77 |
3 | 无 | 1 |
1 | 2, 3 | 2433 |
现在可以计算任意有根树的哈希了,不过下面可以通过换根加速计算。
换根
因为
而如果要换到一个相连的 ,那么
注意,只有 和 的从属关系发生变化,因此其它子树的哈希不变,因此在新的树里,
而
对于其它节点, 。同时,哈希规则不变,因此它总是代表树的结构——如果两个树的结构相同,那么它们的哈希值也相同。
这样程序实际复杂度在 ,可能常数比较大。
代码
为了避免 hash 碰撞,把 可以任性地取复杂一些。
#include <vector>
#include <functional>
#include <iostream>
#include <map>
using namespace std;
using ull = unsigned long long;
constexpr ull mask = 0x1234567898765432;
ull f(ull x) {
x ^= mask;
x ^= x << 7;
x ^= x >> 13;
x ^= x << 11;
x ^= mask;
return x;
}
int main()
{
int n;
cin >> n;
vector<vector<int>> next(n, vector<int>());
for(int i = 0; i < n - 1; ++i) {
int a, b;
cin >> a >> b;
--a;
--b;
next[a].push_back(b);
next[b].push_back(a);
}
vector<ull> root_h(n, 0);
vector<ull> tmp_h(n, 1);
function<void(int,int)> dfs = [&](int cur, int from) {
for(int n: next[cur]) {
if(n != from) dfs(n, cur);
}
for(int n: next[cur]) {
if(n == from) continue;
tmp_h[cur] += f(tmp_h[n]);
}
};
dfs(0, -1);
function<void(int,int)> dfs2 = [&](int cur, int from) {
root_h[cur] = tmp_h[cur];
for(int n: next[cur]) {
if(n == from) continue;
ull next_hash = f(tmp_h[n]);
tmp_h[cur] -= next_hash;
tmp_h[n] += tmp_h[cur];
dfs2(n, cur);
tmp_h[n] -= tmp_h[cur];
tmp_h[cur] += next_hash;
}
};
dfs2(0, -1);
map<ull, int> cnt;
for(auto e: root_h) {
++cnt[e];
}
for(auto e: root_h) {
cout << cnt[e] << endl;
}
return 0;
}