1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
#include <bits/stdc++.h>

using i32 = int;

// Capacity for vertex-indexed arrays (vertices are 1-indexed).
constexpr i32 maxn = 500100;

// Contest idiom: from here on every `int` token means `unsigned long long`,
// so counts and sums below wrap modulo 2^64 instead of overflowing.
// `i32` is the genuine 32-bit int, used for node indices and input values.
#define int unsigned long long

using namespace std;

const double eps = 1e-8;  // NOTE(review): not referenced anywhere in this file.

// Node of a dynamically allocated segment tree indexed by value.
// Node 0 is never allocated (modify hands out ++idx starting from 1), so the
// zero-initialised tr[0] acts as a shared all-zero "empty child".
struct node {
    i32 s0, s1;         // left / right child indices into tr (0 = no child)
    int cnt, sum, ts;   // multiplicity, sum of values, sum of squared values in range
};

node tr[30 * maxn];             // node pool; tr[1..idx] are in use
int res[maxn];                  // per-vertex answer filled in by dfs
i32 a[maxn], rt[maxn], idx;     // input values, per-vertex tree roots, last allocated node
vector<i32> e[maxn], s[maxn];   // adjacency lists, per-vertex lists of collected values

// Recompute node u's aggregates from its two children.
// An absent child is node 0 and therefore contributes zeros.
void pushup(int u) {
    const node &lo = tr[tr[u].s0];
    const node &hi = tr[tr[u].s1];
    tr[u].cnt = lo.cnt + hi.cnt;
    tr[u].sum = lo.sum + hi.sum;
    tr[u].ts = lo.ts + hi.ts;
}
// Add d copies of value x to the segment tree rooted at u, which covers [l, r].
// Missing nodes (including u itself when it is 0) are allocated top-down.
// The reached leaf accumulates cnt += d, sum += d*x, ts += x*x*d; every
// ancestor on the path is then refreshed bottom-up with pushup.
void modify(i32 &u, int l, int r, int x, int d) {
    i32 path[72];       // internal nodes on the root-to-leaf path ([1, 1e6] needs 20)
    i32 depth = 0;
    i32 *slot = &u;     // child link to follow next (filled in if still 0)
    for (;;) {
        if (!*slot) *slot = ++idx;
        const i32 cur = *slot;
        if (l == r) {
            tr[cur].cnt += d;
            tr[cur].sum += d * x;
            tr[cur].ts += x * x * d;
            break;
        }
        path[depth++] = cur;
        const int mid = (l + r) >> 1;
        if (x <= mid) {
            slot = &tr[cur].s0;
            r = mid;
        } else {
            slot = &tr[cur].s1;
            l = mid + 1;
        }
    }
    // Same order as unwinding the recursion: deepest ancestor first.
    while (depth > 0) pushup(path[--depth]);
}
// Merge the segment tree rooted at v into the one rooted at u (both over [l, r]).
// u is updated in place. Where u has no node, v's subtree is adopted as-is,
// so the two trees then share nodes; treat v as consumed afterwards.
void merge(i32 &u, int v, int l, int r) {
    if (!v) return;                   // nothing to add
    if (!u) {                         // adopt v's subtree wholesale
        u = static_cast<i32>(v);
        return;
    }
    if (l == r) {                     // leaves: add aggregates directly
        tr[u].cnt += tr[v].cnt;
        tr[u].sum += tr[v].sum;
        tr[u].ts += tr[v].ts;
        return;
    }
    const int mid = (l + r) >> 1;
    merge(tr[u].s0, tr[v].s0, l, mid);
    merge(tr[u].s1, tr[v].s1, mid + 1, r);
    pushup(u);
}
// Aggregate the values w stored under node u (covering [l, r]) with w >= x.
// Result fields: cnt = number of such w, sum = Σ w, ts = Σ w·(w − x), each
// value weighted by its multiplicity (unsigned 64-bit arithmetic under the
// file's `int` macro). s0/s1 of the result carry no meaning and are zeroed.
node query(int u, int l, int r, int x) {
    if (!u) return {0, 0, 0, 0, 0};
    if (l >= x) {
        // Whole range lies in [x, ∞): Σ w(w − x) = Σ w² − x·Σ w.
        node whole{};
        whole.sum = tr[u].sum;
        whole.ts = tr[u].ts - tr[u].sum * x;
        whole.cnt = tr[u].cnt;
        return whole;
    }
    const int mid = (l + r) >> 1;
    // The right half may hold values >= x; the left half only if mid >= x.
    node acc = query(tr[u].s1, mid + 1, r, x);
    if (mid >= x) {
        const node left = query(tr[u].s0, l, mid, x);
        acc.sum += left.sum;
        acc.cnt += left.cnt;
        acc.ts += left.ts;
    }
    return acc;
}
void dfs(int x, int fa) { for (auto u : e[x]) { if(u == fa) continue; dfs(u, x); res[x] = res[x] + res[u]; if(s[u].size() > s[x].size()) swap(s[x], s[u]), swap(rt[x], rt[u]); for (auto v : s[u]) { if(x == 2) { int p = 1; } auto t = query(rt[x], 1, 1e6, v); int s1 = t.sum, n1 = t.cnt; int s2 = tr[rt[x]].sum - s1, n2 = tr[rt[x]].cnt - n1; res[x] = res[x] + 2 * t.ts + v * (v * n2 - s2) * 2; s[x].push_back(v); } s[u].clear(); merge(rt[x], rt[u], 1, 1e6); } }
// Read one tree and its vertex values, run the subtree computation from vertex 1,
// and print the XOR of all per-vertex answers.
// Input: n, then n − 1 undirected edges "u v", then n values a[1..n].
// NOTE(review): values are inserted into a tree over [1, 1e6]; assumes 1 <= a[i] <= 1e6 — confirm.
void solve() {
    int n;
    cin >> n;
    for (int edge = 1; edge < n; ++edge) {
        int from, to;
        cin >> from >> to;
        e[from].push_back(to);
        e[to].push_back(from);
    }
    for (int vtx = 1; vtx <= n; ++vtx) {
        cin >> a[vtx];
        // Seed each vertex with its own value: list + single-leaf segment tree.
        s[vtx].push_back(a[vtx]);
        modify(rt[vtx], 1, 1e6, a[vtx], 1);
    }
    dfs(1, -1);
    int checksum = 0;
    for (int vtx = 1; vtx <= n; ++vtx) {
        checksum ^= res[vtx];
    }
    cout << checksum << '\n';
}
// Entry point. The return type is spelled `signed` because `int` is a macro
// for `unsigned long long` in this file, and main must return a real int.
signed main() {
    // Fast stream I/O: no sync with C stdio, no flush-before-read.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    // Single test case.
    for (int remaining = 1; remaining > 0; --remaining) {
        solve();
    }
    return 0;
}
|