"""Make every residue class mod m hold exactly n // m array elements.

Input (stdin): ``n m`` followed by ``n`` integers.  The only allowed operation
is incrementing an element by 1.  Output (stdout): the minimum total number of
increments, then one resulting array (space-separated).
"""
import sys


def solve(n, m, a):
    """Return ``(moves, result)`` for the equalize-remainders problem.

    ``moves`` is the total number of +1 increments performed and ``result`` is
    a new list (``a`` is left unmodified) in which each remainder ``0..m-1``
    occurs exactly ``n // m`` times.  Assumes ``m`` divides ``n``.

    Greedy: sweep residues in increasing order twice (``2 * m`` steps) so the
    sweep can wrap around from ``m - 1`` back to ``0``.  Surplus elements of a
    bucket are pushed onto a stack tagged with the step at which they left;
    later deficient buckets take from the stack, and an element moved from
    step ``origin`` to step ``step`` costs ``step - origin`` increments.
    """
    target = n // m
    buckets = [[] for _ in range(m)]
    for idx, value in enumerate(a):
        buckets[value % m].append(idx)

    result = list(a)
    moves = 0
    surplus = []  # stack of (element index, sweep step it was pushed at)
    for step in range(2 * m):
        bucket = buckets[step % m]
        # Push everything beyond the target count onto the surplus stack.
        while len(bucket) > target:
            surplus.append((bucket.pop(), step))
        # Fill this bucket from the most recently pushed surplus elements.
        while len(bucket) < target and surplus:
            idx, origin = surplus.pop()
            bucket.append(idx)
            delta = step - origin
            result[idx] += delta
            moves += delta
    return moves, result


def main():
    """Read the problem from stdin and print the answer to stdout."""
    data = sys.stdin.buffer.read().split()
    n, m = int(data[0]), int(data[1])
    a = list(map(int, data[2:2 + n]))
    moves, result = solve(n, m, a)
    sys.stdout.write(f"{moves}\n{' '.join(map(str, result))}\n")


if __name__ == "__main__":
    main()