"""Balance the values' remainders modulo ``m`` by incrementing elements.

Input (stdin): ``n m`` followed by ``n`` integers.
Output (stdout): the total number of +1 increments applied, then the adjusted
array on one line.

NOTE(review): this looks like Codeforces 999D "Equalize the Remainders"
(every remainder 0..m-1 must occur exactly n // m times, each operation adds 1
to one element, minimise the operation count). That task is inferred, not
stated in this file. Confirm against the original problem statement.
"""
import sys


def solve(m, values):
    """Make each remainder ``0..m-1`` occur ``len(values) // m`` times.

    Remainders are scanned in increasing order. Every element in a remainder
    class that holds more than ``k = len(values) // m`` elements is moved
    forward, cyclically, to the nearest class that still holds fewer than
    ``k``. Moving from class ``i`` to class ``j`` adds ``(j - i) % m`` to the
    element.

    Args:
        m: The modulus; must be positive.
        values: The input integers. Not modified.

    Returns:
        A tuple ``(total, result)``. ``total`` is the sum of all increments.
        ``result`` is a new list in which each remainder class has exactly
        ``len(values) // m`` members.

    Raises:
        ValueError: If ``m`` is not positive, or if ``len(values)`` is not a
            multiple of ``m``. In that case no balanced assignment exists, and
            the search for a free class would never terminate.
    """
    if m <= 0:
        raise ValueError("m must be positive")
    if len(values) % m:
        raise ValueError("len(values) must be a multiple of m")

    k = len(values) // m
    # buckets[r] holds the indices of elements whose current remainder is r.
    buckets = [[] for _ in range(m)]
    for idx, value in enumerate(values):
        buckets[value % m].append(idx)

    result = list(values)
    total = 0
    # `target` only moves forward. It may exceed m - 1, which means wrapping
    # around to lower remainders, so it is always reduced with `% m`.
    target = 0
    for rem in range(m):
        while len(buckets[rem]) > k:
            # Advance to the first class at or after `rem` with spare room.
            # Termination is guaranteed by the divisibility check above:
            # a surplus anywhere implies a deficit somewhere.
            while target < rem or len(buckets[target % m]) >= k:
                target += 1
            idx = buckets[rem].pop()
            step = (target - rem) % m
            result[idx] += step
            total += step
            buckets[target % m].append(idx)
    return total, result


def main():
    """Read the problem from stdin, solve it, and print the answer."""
    data = sys.stdin.read().split()
    n, m = int(data[0]), int(data[1])
    values = [int(tok) for tok in data[2:2 + n]]
    total, result = solve(m, values)
    print(total)
    print(" ".join(map(str, result)))


if __name__ == "__main__":
    main()