?
# | Author | Problem | Lang | Verdict | Time | Memory | Sent | Judged | |
---|---|---|---|---|---|---|---|---|---|
256192789 | Practice: luogu_bot3 | 1172E - 24 | C++20 (GCC 13-64) | Accepted | 4124 ms | 80864 KB | 2024-04-12 06:32:35 | 2024-04-12 06:32:35 |
// LUOGU_RID: 155433556
// Offline solution submitted for Codeforces 1172E.
// Every colour is handled independently: starting from a link-cut tree in which
// every vertex is linked to its parent, the vertices carrying that colour are
// toggled (cut / re-linked) in time order, and the change of LCT::sum caused by
// each toggle is charged to that event's timestamp in ans[]. Prefix sums of
// ans[] are printed for times 0..m.
#include <cstdio>
#include <utility>
#include <vector>
#define ci const int
#define ll long long
#define ls ch[x][0]
#define rs ch[x][1]
using namespace std;

ci N=4e5+5;
int n,m;
int fa[N];                      // fa[v]: parent of v, tree rooted at the virtual vertex n+1
vector<int>g[N];                // adjacency lists of the input tree (+ edge n+1 -> 1)
ll sqr(ll x){ return x*x; }
int col[N];                     // current colour of each vertex (updated while reading queries)
ll ans[N];                      // ans[t]: contribution recorded at time t; prefix-summed on output
vector<pair<int,int> >vec[N];   // vec[c]: (time, vertex) events where vertex v gains or loses colour c
bool tp[N];                     // tp[v]: v is currently linked to fa[v] inside the LCT

// Link-cut tree that also tracks information about virtual (non-preferred) children.
//   ch/f    - splay children / parent pointers (f also serves as path-parent)
//   bz      - lazy reversal flag (only ever set by rev(), which only pushdown() calls)
//   v[x]    - x's own weight (1 for vertices 1..n, set in main; 0 for n+1)
//             plus the total size of x's virtual children
//   siz[x]  - sum of v over x's splay subtree
//   siz2[x] - sum of squared sizes of x's virtual children
//   sum     - running total kept up to date by Link()/Cut()
struct LCT{
    int ch[N][2],bz[N],f[N],v[N],siz[N];
    ll siz2[N],sum;

    // Recompute siz[x] from its splay children and its own weight.
    void upd(ci x){ siz[x]=siz[ls]+v[x]+siz[rs]; }
    // True if x is the right splay child of f[x].
    bool gt(ci x){ return ch[f[x]][1]==x; }
    // True if x is NOT the root of its splay tree.
    bool nrt(ci x){ return ch[f[x]][0]==x||ch[f[x]][1]==x; }
    void rev(ci x){ swap(ls,rs),bz[x]^=1; }
    void pushdown(ci x){ if(bz[x])rev(ls),rev(rs),bz[x]=0; }

    // Standard splay rotation of x above its parent.
    void Rotate(ci x){
        ci y=f[x],z=f[y],d=gt(x);
        if(nrt(y))ch[z][gt(y)]=x;
        ch[y][d]=ch[x][d^1];
        if(ch[y][d])f[ch[y][d]]=y;
        ch[x][d^1]=y,f[y]=x,f[x]=z;
        upd(y),upd(x);
    }

    int st[N];  // explicit stack used to push lazy tags top-down before splaying
    void Splay(int x){
        int top=0,k=x;
        st[++top]=x;
        while(nrt(k))st[++top]=k=f[k];
        while(top)pushdown(st[top--]);
        for(int y=f[x];nrt(x);Rotate(x),y=f[x])
            if(nrt(y))Rotate(gt(x)^gt(y)?x:y);
    }

    // Make the root-to-x path preferred. The old preferred child of each
    // vertex on the way becomes virtual (added to v/siz2), and the new one
    // stops being virtual (subtracted).
    void Access(int x){
        for(int y=0;x;x=f[y=x]){
            Splay(x);
            v[x]+=siz[rs],siz2[x]+=(ll)siz[rs]*siz[rs];
            rs=y;
            v[x]-=siz[rs],siz2[x]-=(ll)siz[rs]*siz[rs];
            upd(x);
        }
    }

    // 0 if x is not linked to its parent. Otherwise, with the root-to-x path
    // made preferred and the topmost vertex y splayed, returns the size of y's
    // right splay subtree: the branch hanging below y that contains x.
    int qry(ci x){
        if(!tp[x])return 0;
        Access(x),Splay(x);
        int y=x;
        while(ch[y][0])pushdown(y),y=ch[y][0];
        Splay(y);
        return siz[ch[y][1]];
    }

    // Attach x under fa[x] and update sum: drop the squared branch size seen
    // from fa[x] and x's own children terms (siz2[x]), then add the squared
    // size of the merged branch.
    void Link(ci x){
        sum-=sqr(qry(fa[x]));
        Access(x),Splay(x),Access(fa[x]),Splay(fa[x]);
        sum-=siz2[x];
        f[x]=fa[x],v[fa[x]]+=siz[x],siz2[fa[x]]+=(ll)siz[x]*siz[x],upd(fa[x]);
        tp[x]=1;
        sum+=sqr(qry(x));
    }

    // Detach x from fa[x] (inverse of Link) and update sum accordingly.
    void Cut(ci x){
        sum-=sqr(qry(x));
        Access(x),Splay(x);
        tp[x]=0;
        f[ls]=0,ls=0;
        upd(x);
        sum+=siz2[x]+sqr(qry(fa[x]));
    }
}A;

// Fill fa[] for the tree rooted at `root`.
// Iterative on purpose: a recursive DFS on a path-shaped tree with ~4e5
// vertices can exhaust a default-sized call stack.
void dfs(ci root){
    vector<int> stk(1,root);
    while(!stk.empty()){
        const int x=stk.back();
        stk.pop_back();
        for(int y:g[x])
            if(y!=fa[x]) fa[y]=x,stk.push_back(y);
    }
}

int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i){
        scanf("%d",&col[i]);
        vec[col[i]].push_back(make_pair(0,i));   // initial colour counts as a time-0 event
    }
    for(int i=1,x,y;i<n;++i){
        scanf("%d%d",&x,&y);
        g[x].push_back(y);
        g[y].push_back(x);
    }
    // Hang the tree below a virtual root n+1 (weight 0) so vertex 1 also has a parent.
    g[n+1].push_back(1);
    dfs(n+1);
    for(int i=1;i<=n;++i)A.v[i]=1,A.upd(i);

    // A recolouring at time t toggles x for both its old and its new colour.
    for(int t=1,x,c;t<=m;++t){
        scanf("%d%d",&x,&c);
        vec[col[x]].push_back(make_pair(t,x));
        vec[c].push_back(make_pair(t,x));
        col[x]=c;
    }

    for(int i=1;i<=n;++i)A.Link(i);   // Link() also sets tp[i]
    const ll base=A.sum;              // sum with every vertex linked

    for(int c=1;c<=n;++c){
        if(vec[c].empty())continue;
        // Sanity check: each colour must start from the fully linked state.
        if(A.sum!=base)return -1;
        for(const auto&ev:vec[c]){
            const ll before=A.sum;
            if(tp[ev.second])A.Cut(ev.second);
            else A.Link(ev.second);
            ans[ev.first]-=A.sum-before;
        }
        // Restore the fully linked state for the next colour.
        for(const auto&ev:vec[c])
            if(!tp[ev.second])A.Link(ev.second);
    }

    // Prefix sums; start at i=0 without touching ans[-1] (out of bounds).
    for(int i=0;i<=m;++i){
        if(i>0)ans[i]+=ans[i-1];
        printf("%lld\n",ans[i]);
    }
}
?
?
?
?