General
 
 
# Author Problem Lang Verdict Time Memory Sent Judged  
256192789 Practice:
luogu_bot3
1172E - 24 C++20 (GCC 13-64) Accepted 4124 ms 80864 KB 2024-04-12 06:32:35 2024-04-12 06:32:35
→ Source
// LUOGU_RID: 155433556
#include<bits/stdc++.h>
// Shorthand type aliases used throughout the file.
#define ci const int
#define ll long long
// Left / right splay child of the node named `x` in the enclosing scope
// (these macros only work where a variable called x is visible).
#define ls ch[x][0]
#define rs ch[x][1]
using namespace std;
// Capacity of every per-node array in this file.
ci N=4e5+5;
// n, m: the two integers read first in main (node count and operation count);
// fa[x]: parent of node x as assigned by dfs.
int n,m,fa[N];
// g[x]: adjacency list of node x, filled from the edge list read in main.
vector<int>g[N];
// Returns the square of x as a 64-bit signed value.
long long sqr(long long x){
	const long long squared=x*x;
	return squared;
}
// col[i]: current colour of node i; read in main and overwritten by each operation.
int col[N];
// ans[t]: per-timestamp contribution accumulated in main (t = 0 for the initial
// colouring, t = 1..m for the operations), later turned into prefix sums.
ll ans[N];
// vec[c]: (timestamp, node) events recorded for colour c, in the order they were read.
vector<pair<int,int> >vec[N];
// tp[x]: set to 1 by LCT::Link(x) and to 0 by LCT::Cut(x); LCT::qry returns 0 when it is 0.
bool tp[N];
// Splay-based dynamic forest (presumably a link-cut tree, per the name) over node
// indices; index 0 is used as the "no node" sentinel. Besides the usual splay
// bookkeeping it tracks, per node, the total weight hanging off it outside its own
// splay (folded into v) and the sum of squares of those weights (siz2), and keeps a
// running total `sum` that Link and Cut adjust.
// Depends on the file-level globals fa[] and tp[] and on the helper sqr().
struct LCT{
	// ch[x][0/1]: splay children; bz[x]: pending child-swap flag; f[x]: parent pointer
	// (either a splay parent or a link to another splay, see nrt);
	// v[x]: x's own weight plus siz of the subtrees attached to x outside its splay;
	// siz[x]: sum of v over x's splay subtree.
	int ch[N][2],bz[N],f[N],v[N],siz[N];
	// siz2[x]: sum of siz^2 over the subtrees attached to x outside its splay;
	// sum: running aggregate maintained by Link/Cut.
	ll siz2[N],sum;
	// Recomputes siz[x] from its two splay children and v[x].
	void upd(ci x){
		siz[x]=siz[ls]+v[x]+siz[rs];
	}
	// Returns true when x is the right child of f[x].
	bool gt(ci x){
		return ch[f[x]][1]==x;
	}
	// Returns true when x is a real splay child of f[x], i.e. x is not the root of its splay.
	bool nrt(ci x){
		return ch[f[x]][0]==x||ch[f[x]][1]==x;
	}
	// Swaps x's children and toggles its pending-swap flag.
	// NOTE(review): within this struct rev is only reached through pushdown, and nothing
	// visible here ever sets bz, so the flag appears to stay 0 in this program.
	void rev(ci x){
		swap(ls,rs),bz[x]^=1;
	}
	// Pushes a pending swap from x down to both children and clears it on x.
	void pushdown(ci x){
		if(bz[x])rev(ls),rev(rs),bz[x]=0;
	}
	// Single splay rotation lifting x above its parent y; refreshes siz of y, then x.
	void Rotate(ci x){
		ci y=f[x],z=f[y],d=gt(x);
		// Only rewire z's child pointer when y is a real splay child of z.
		if(nrt(y))ch[z][gt(y)]=x;
		ch[y][d]=ch[x][d^1];
		if(ch[y][d])f[ch[y][d]]=y;
		ch[x][d^1]=y,f[y]=x,f[x]=z,
		upd(y),upd(x);
	}
	// Scratch stack used by Splay to push flags down from the splay root to x.
	int st[N];
	// Rotates x to the root of its splay.
	void Splay(int x){
		// Collect the ancestors of x up to the splay root, then push down top-to-bottom.
		int top=0,k=x;st[++top]=x;
		while(nrt(k))st[++top]=k=f[k];
		while(top)pushdown(st[top--]);
		// Double rotations: rotate x first when x and y are on opposite sides, else y first.
		for(int y=f[x];nrt(x);Rotate(x),y=f[x])
			if(nrt(y))Rotate(gt(x)^gt(y)?x:y);
	}
	// Walks from x upward through f[], at each step splaying the current node and
	// replacing its right child with the previously processed splay (y).
	void Access(int x){
		for(int y=0;x;x=f[y=x])
			// The old right child becomes an outside attachment: add its siz to v and
			// its siz^2 to siz2.
			Splay(x),v[x]+=siz[rs],siz2[x]+=(ll)siz[rs]*siz[rs],
			// The new right child stops being an outside attachment: subtract the same
			// quantities for it.
			rs=y,v[x]-=siz[rs],siz2[x]-=(ll)siz[rs]*siz[rs],
			upd(x);
	}
	// Returns 0 when tp[x] is 0. Otherwise accesses x, descends to the leftmost node y
	// of its splay, splays y and returns siz of y's right child.
	// NOTE(review): presumably this is the weight of x's component excluding its
	// topmost node y — confirm against the intended invariant.
	int qry(ci x){
		if(!tp[x])return 0;
		Access(x),Splay(x);
		int y=x;
		while(ch[y][0])pushdown(y),y=ch[y][0];
		Splay(y);
		return siz[ch[y][1]];
	}
	// Attaches x beneath fa[x] and sets tp[x]=1, updating `sum`:
	// subtracts sqr(qry(fa[x])) and siz2[x] before attaching, adds sqr(qry(x)) after.
	void Link(ci x){
		sum-=sqr(qry(fa[x])),
		Access(x),Splay(x),Access(fa[x]),Splay(fa[x]),
		sum-=siz2[x],
		// x becomes an outside attachment of fa[x]; account for its siz and siz^2 there.
		f[x]=fa[x],v[fa[x]]+=siz[x],siz2[fa[x]]+=(ll)siz[x]*siz[x],upd(fa[x]),
		tp[x]=1,
		sum+=sqr(qry(x));
	}
	// Detaches x from whatever sits to its left after Access/Splay and sets tp[x]=0,
	// updating `sum`: subtracts sqr(qry(x)) before, adds siz2[x]+sqr(qry(fa[x])) after.
	void Cut(ci x){
		sum-=sqr(qry(x));
		Access(x),Splay(x),
		tp[x]=0,
		// Drop the left splay child of x (clears both directions of the link).
		f[ls]=0,ls=0,
		upd(x);
		sum+=siz2[x]+sqr(qry(fa[x]));
	}
}A;
// Roots the tree stored in g[] at x by filling fa[]: for every node reached,
// each neighbour other than its own parent gets that node recorded as its parent.
// Uses an explicit work stack instead of recursion so that a path-shaped tree with
// hundreds of thousands of nodes cannot overflow the call stack. On a tree the
// resulting fa[] does not depend on the visiting order.
void dfs(ci x){
	vector<int>pending;
	pending.push_back(x);
	while(!pending.empty()){
		ci cur=pending.back();
		pending.pop_back();
		for(int y:g[cur])
			if(y!=fa[cur])
				fa[y]=cur,pending.push_back(y);
	}
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;++i)
		scanf("%d",&col[i]),
		vec[col[i]].push_back(make_pair(0,i));
	for(int i=1,x,y;i<n;++i)
		scanf("%d%d",&x,&y),
		g[x].push_back(y),
		g[y].push_back(x);
	g[n+1].push_back(1),dfs(n+1);
	for(int i=1;i<=n;++i)A.v[i]=1,A.upd(i);
	for(int t=1,x,c;t<=m;++t){
		scanf("%d%d",&x,&c),
		vec[col[x]].push_back(make_pair(t,x)),
		vec[c].push_back(make_pair(t,x)),
		col[x]=c;
	}
	for(int i=1;i<=n;++i)A.Link(i),tp[i]=1;
	ll tmp=A.sum;
	for(int c=1;c<=n;++c){
		if(vec[c].empty())continue;
		if(A.sum!=tmp)return -1;
		for(auto tmp:vec[c]){
			ll las=A.sum;
			if(tp[tmp.second])A.Cut(tmp.second);
			else A.Link(tmp.second);
			ll delta=A.sum-las;
			ans[tmp.first]-=delta;
		}
		for(auto tmp:vec[c])
			if(!tp[tmp.second])
				A.Link(tmp.second);
	}
	for(int i=0;i<=m;++i)ans[i]+=ans[i-1],printf("%lld\n",ans[i]);
}
?
Time: ? ms, memory: ? KB
Verdict: ?
Input
?
Participant's output
?
Jury's answer
?
Checker comment
?
Diagnostics
?
Click to see test details