题目大意


给一棵根为$1$的有根树,点$i$具有一个权值$a_i$。

定义一个点对的值$f(u,v)=max(a_u,a_v)×\vert a_u−a_v\vert$。

你需要对于每个节点$i$,计算 $ans_i=∑_{u∈subtree(i),v∈subtree(i)}f(u,v)$ ,其中$subtree(i)$表示$i$的子树。

请你输出 $⊕(ans_i \mod 2^{64})$ ,其中 $⊕$ 表示$XOR$。

解题思路


容易联想到启发式合并,顺着这个思路,思考怎么计算将一个数合并进一个集合对答案的贡献。假设要在当前集合$S$放进一个数$x$,对于小于$x$的数,贡献应该为$\sum_{u\in S,\ u\lt x}(x-u) \cdot x$;对于大于等于$x$的数,需要引进两个变量$ts$和$sum$,其中$ts = \sum_{u\in S, \ u \ge x}u^2$,$sum=\sum_{u\in S, \ u\ge x}u$,那么此时的贡献就为$ts - sum \cdot x$。(关于如何想到这一点,对于$S$中大于$x$的元素$u$,它的贡献为$u \cdot\vert u - x\vert$,与这个元素跟$x$的差值有关,所以多存一个变量记录$u \cdot u$的和,这样就可以快速算出贡献)

所以,对于每个集合和一个$x$,都需要能够快速求出这三个值$cnt = \sum_{u\in S}[u \ge x]$,$sum$,$ts$,显然线段树能实现这个功能,于是我们在启发式合并的基础上再套一个线段树合并,这题就结束了。

参考代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include <bits/stdc++.h>
#define maxn 500100
using i32 = int;
#define int unsigned long long
using namespace std;
const double eps = 1e-8;
struct node {
i32 s0, s1;
int cnt, sum, ts;
}tr[30 * maxn];
int res[maxn];
i32 a[maxn], rt[maxn], idx;
vector<i32> e[maxn], s[maxn];
void pushup(int u) {
tr[u].cnt = tr[tr[u].s0].cnt + tr[tr[u].s1].cnt;
tr[u].sum = tr[tr[u].s0].sum + tr[tr[u].s1].sum;
tr[u].ts = tr[tr[u].s0].ts + tr[tr[u].s1].ts;
}

void modify(i32 &u, int l, int r, int x, int d) {
if(!u) u = ++idx;
if(l == r) {
tr[u].cnt += d;
tr[u].sum += d * x;
tr[u].ts += x * x * d;
}
else {
int mid = l + r >> 1;
if(x <= mid) modify(tr[u].s0, l, mid, x, d);
if(x > mid) modify(tr[u].s1, mid + 1, r, x, d);
pushup(u);
}
}

void merge(i32 &u, int v, int l, int r) {
if(!u || !v) {u = u + v; return ;}
if(l == r) tr[u].cnt += tr[v].cnt, tr[u].sum += tr[v].sum, tr[u].ts += tr[v].ts;
else {
int mid = l + r >> 1;
merge(tr[u].s0, tr[v].s0, l, mid);
merge(tr[u].s1, tr[v].s1, mid + 1, r);
pushup(u);
}
}

node query(int u, int l, int r, int x) {
if(!u) return {0, 0, 0, 0, 0};
node tt;
if(l >= x) {
tt.sum = tr[u].sum;
tt.ts = tr[u].ts - tr[u].sum * x;
tt.cnt = tr[u].cnt;
return tt;
}
int mid = l + r >> 1;
auto tmp = query(tr[u].s1, mid + 1, r, x);
tt = tmp;
if(mid >= x) {
tmp = query(tr[u].s0, l, mid, x);
tt.sum += tmp.sum; tt.cnt += tmp.cnt; tt.ts += tmp.ts;
}
return tt;
}

void dfs(int x, int fa) {
for (auto u : e[x]) {
if(u == fa) continue;
dfs(u, x);
res[x] = res[x] + res[u];
if(s[u].size() > s[x].size()) swap(s[x], s[u]), swap(rt[x], rt[u]);
for (auto v : s[u]) {
if(x == 2) {
int p = 1;
}
auto t = query(rt[x], 1, 1e6, v);
int s1 = t.sum, n1 = t.cnt;
int s2 = tr[rt[x]].sum - s1, n2 = tr[rt[x]].cnt - n1;
res[x] = res[x] + 2 * t.ts + v * (v * n2 - s2) * 2;
s[x].push_back(v);
}
s[u].clear();
merge(rt[x], rt[u], 1, 1e6);
}
}

void solve() {
int n; cin >> n;
for (int i = 1; i < n; ++i) {
int u, v;
cin >> u >> v;
e[u].push_back(v); e[v].push_back(u);
}
for (int i = 1; i <= n; ++i) {
cin >> a[i];
s[i].push_back(a[i]);
modify(rt[i], 1, 1e6, a[i], 1);
}
dfs(1, -1);
int ans = 0;
for (int i = 1; i <= n; ++i) {
// cout << res[i] << " \n"[i == n];
ans = ans ^ res[i];
}
cout << ans << '\n';
}

signed main() {
ios::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
int t = 1;
while (t--) {
solve();
}
return 0;
}