楼上两篇题解写的有一点点复杂,有map还写了离散化……
差分固然是一种理解方式,但其实有一种更好的理解方法和更简洁的代码。
那么现在我就来讲一讲
题意简述
文字语言:求树上最大权值随祖孙关系不降的点集大小
数学语言:求 \(|S_{max}|\) 使得 \(\forall{i,j(ancestor\ of \ i)\in S}, w_i\leq w_j\)
为了方便描述,我们定义这种集合为“树上LIS”。
题解
考虑采用数学归纳法
类似处理序列LIS问题,对于每一个点 \(u\) 使用multiset维护一个集合 \(f_u\) 满足以下性质
\(f_{u,i}\) 表示在 \(u\) 的子树中选择 \(i\) 个点组成的所有树上LIS中,级别值 \(w\) 最小值最大的那一个。
以 \(u\) 为根节点的 \(ans_u=|f_u|\)(\(|f_u|\)表示集合 \(f_u\) 的大小)(该性质可由上述性质发现)
对于任意一个叶子节点 \(u\), \(f_u\)显然只含有 \(w_u\),满足树上LIS性质。
再考虑不是叶子节点的 \(u\)
假设点 \(u\) 的所有孩子 \(v\) 的 \(f_v\) 已经满足求出并满足上述性质,我们应该如何求出 \(u\) 的 \(f_u\) 呢?
首先,显然 \(u\) 的所有孩子不会相互影响,要从以 \(u\) 为根节点的子树(除 \(u\) )中选出大小为 \(i\) 的树上LIS,可以直接贪心地选所有孩子集合中最大的 \(i\) 个,于是只需将全部 \(f_v\) 取并集并排序即可,于是可以直接将孩子们的 \(f_v\) 集合全部启发式合并丢入 \(f_u\) 的multiset ,记 \(S=\bigcup_{v\in u.son}f_v\)
现在我们考虑将 \(u\) 加入 \(S\) 集合并使集合满足性质
我们直接在multiset上二分出第一个 \(i\) 满足 \(f_{u,i}\geq w_u\) 那么我们将 \(u\) 接在 \(i\) 前显然是最优方案,此时 \(f_{u,i-1}\) 就可以被 \(w_u\) 替换,那么现在的集合就是我们要求的 \(f_u\),并且满足树上LIS性质。
按照这样的方式在树上dfs即可求出 \(f_1\),此时答案即为 \(|f_1|\)。
复杂度证明
该算法的复杂度为 \(O(nlog^2n)\)
考虑同样采用数学归纳法
记 \(T_u\) 表示处理出 \(f_u\) 的时间复杂度,\(S_u\)表示 \(u\) 的子树大小
我们需要证明 \(T_u=S_ulog^2S_u\)
对于任意一个叶子节点 \(u\),\(S_u=1\),此时只需在multiset中插入 \(w_u\) 复杂度为 \(O(1)\),满足\(T_i=S_ulog^2S_u\)
再考虑不是叶子节点的 \(u\)
假设点 \(u\) 的所有孩子 \(v\) 的 \(T_v=S_vlog^2S_v\)
那么 \(T_i=\sum_{v\in u.son}T_v+T_{merge}\)
因为子孙们包含的节点个数\(\sum_{v\in u.son}S_v+1=S_u\)
所以\(\sum_{v\in u.son}T_v\leq S_ulog^2S_u\)
启发式合并的复杂度为 \(S_ulogS_u\),使用multiset维护加一个log,\(T_{merge}=S_ulog^2S_u\)
所以\(T_u\)与 \(S_ulog^2S_u\) 同阶
证毕。
代码
#includeusing namespace std;const int N = 2e5 + 5;multiset f[N];multiset ::iterator it;int n, w[N], ans;int h[N], to[N], nxt[N], t;bool comp(int x, int y) { return w[x] < w[y]; }void add(int u, int v) { to[++t] = v, nxt[t] = h[u], h[u] = t; }void merge(int u, int v) { if(f[u].size() < f[v].size()) swap(f[u], f[v]); for(it = f[v].begin(); it != f[v].end(); ++it) f[u].insert(*it);}void dfs(int u) { for(int i = h[u]; i; i = nxt[i]) dfs(to[i]), merge(u, to[i]); f[u].insert(w[u]); it = f[u].lower_bound(w[u]); if(it != f[u].begin()) f[u].erase(--it);}int main() { scanf("%d", &n); for(int i = 1; i <= n; ++i) scanf("%d", &w[i]); for(int i = 2; i <= n; ++i) { int f; scanf("%d", &f); add(f, i); } dfs(1); printf("%d", f[1].size());}