#include<vector>
#include<cstdio>
#define pb push_back
#define ll long long
using namespace std;
const int N=1e6;
vector<int>cx[N],cy[N],cz[N];
int fa[N],fir[N],nxt[N],to[N],cnt=0,n,m,col[N],ch[N][2],f[N],sz[N],sz1[N];
ll sz2[N],dt[N],ans=0;
ll P(int a)
{
return 1ll*a*a;
}
void add(int a,int b)
{
nxt[++cnt]=fir[a];
to[cnt]=b;
fir[a]=cnt;
}
void adda(int a,int b,int c,int d)
{
cx[a].pb(b);
cy[a].pb(c);
cz[a].pb(d);
}
void DFS(int u)
{//cout<<"*"<<u<<" "<<fa[u]<<"\n";
for(int i=fir[u];i;i=nxt[i]) if(fa[u]!=to[i]) fa[to[i]]=u,DFS(to[i]);
}
bool isrt(int x)
{
return (ch[f[x]][0]!=x)&(ch[f[x]][1]!=x);
}
void pushup(int x)
{
sz[x]=sz1[x]+sz[ch[x][0]]+sz[ch[x][1]]+1;
}
bool son(int x)
{
return x==ch[f[x]][1];
}
void rotate(int x)
{
int a=f[x],b=f[a],c=son(x),d=son(a),e=ch[x][!c];
if(!isrt(a)) ch[b][d]=x;ch[x][!c]=a;ch[a][c]=e;
if(e) f[e]=a;f[a]=x;f[x]=b;
pushup(a);
}
void splay(int x)
{
int b,c;
while(!isrt(x))
{
b=f[x];c=f[b];
if(!isrt(b)) {
if(son(b)==son(x)) rotate(b);
else rotate(x);
}
rotate(x);
}
pushup(x);
}
void access(int x)
{
for(int y=0;x;y=x,x=f[x])
{
splay(x);
sz1[x]+=sz[ch[x][1]]-sz[y];
sz2[x]+=P(sz[ch[x][1]])-P(sz[y]);
ch[x][1]=y;
pushup(x);
}
}
int getroot(int x)
{
access(x);splay(x);
while(ch[x][0]) x=ch[x][0];
splay(x);
return x;
}
void link(int u)
{
int v=fa[u];
splay(u);
ans-=P(sz[ch[u][1]]);
ans-=sz2[u];
int w=getroot(v);
access(u);splay(w);
ans-=P(sz[ch[w][1]]);
f[u]=v;
splay(v);
sz1[v]+=sz[u];sz2[v]+= P(sz[u]);
pushup(v);
access(u);splay(w);
ans+=P(sz[ch[w][1]]);
}
void cut(int u)
{
int v=fa[u];
access(u);
ans+=sz2[u];
int w=getroot(v);
access(u);splay(w);
ans-=P(sz[ch[w][1]]);
splay(u);
ch[u][0]=f[ch[u][0]]=0;
pushup(u);splay(w);
ans+=P(sz[ch[w][1]]);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&col[i]),adda(col[i],i,1,0);
for(int i=1,u,v;i<n;i++) {
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
for(int i=1;i<=n+1;i++) sz[i]=1;
fa[1]=n+1;
DFS(1);
for(int i=1,x,y;i<=m;i++) {
scanf("%d%d",&x,&y);
adda(col[x],x,-1,i);
adda(col[x]=y,x,1,i);
}
ll las=1ll*n*n;
for(int i=1;i<=n;i++) link(i);
for(int i=1;i<=n;i++) {
for(int j=0;j<cx[i].size();j++) {
int x=cx[i][j],y=cy[i][j],z=cz[i][j];
if(y==-1) link(x);
else cut(x);
dt[z]-=ans-las;
las=ans;
}
for(int j=cx[i].size()-1;j>=0;j--){
int x=cx[i][j],y=cy[i][j],z=cz[i][j];
if(y==-1) cut(x);
else link(x);
}
las=ans;
}
printf("%I64d\n",ans=dt[0]);
for(int i=1;i<=m;i++) printf("%I64d\n",ans+=dt[i]);
return 0;
}