General
 
 
# Author Problem Lang Verdict Time Memory Sent Judged  
131576407 Practice:
Ayushman_123
884C - 9 PyPy 3 Runtime error on test 8 280 ms 95168 KB 2021-10-11 20:10:11 2021-10-11 20:26:32
→ Source
import math
import os
import sys
from io import BytesIO, IOBase
# Modulus constant (10**9 + 7); not referenced by the code visible in this file.
M = 1000000007
import random
import heapq
import bisect
import time
# Raise the interpreter recursion limit for the recursive helpers defined below.
sys.setrecursionlimit(10**5)
from functools import *
from collections import *
from itertools import *
# Lower bound on the number of bytes requested per os.read() call in FastIO.
BUFSIZE = 8192
import array
class FastIO(IOBase):
    """Byte-level buffered reader/writer built directly on a file descriptor.

    Input is pulled from the descriptor with os.read() in large chunks and
    appended to an in-memory BytesIO; output is accumulated in the same
    BytesIO and only written to the descriptor when flush() is called.
    """
    # Number of newline-terminated lines (or an EOF marker) fetched into the
    # buffer by readline() but not yet handed back to the caller.
    newlines = 0
    def __init__(self, file):
        """Wrap *file*, which must expose fileno() and a ``mode`` string."""
        self._fd = file.fileno()
        self.buffer = BytesIO()
        # Writable when the mode contains "x" or does not contain "r".
        self.writable = "x" in file.mode or "r" not in file.mode
        # Writes go straight into the in-memory buffer; None for read-only files.
        self.write = self.buffer.write if self.writable else None
    def read(self):
        """Read the descriptor until os.read() returns nothing; return the unread buffer bytes."""
        while True:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break
            # Append the chunk at the end of the buffer, then restore the
            # read position so already-buffered data is not skipped.
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines = 0
        return self.buffer.read()
    def readline(self):
        """Return the next line as bytes, fetching more data only when no full line is buffered."""
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            # An empty chunk (not b) adds 1 so the loop also terminates
            # once os.read() reports end of input.
            self.newlines = b.count(b"\n") + (not b)
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines -= 1
        return self.buffer.readline()
    def flush(self):
        """Write the whole buffer to the descriptor and reset it (writable files only)."""
        if self.writable:
            os.write(self._fd, self.buffer.getvalue())
            self.buffer.truncate(0), self.buffer.seek(0)
class IOWrapper(IOBase):
    """Text-mode facade over FastIO that encodes/decodes every call as ASCII."""

    def __init__(self, file):
        """Create a FastIO for *file* and expose str-based write/read/readline."""
        self.buffer = FastIO(file)
        self.flush = self.buffer.flush
        self.writable = self.buffer.writable

        # The closures look up self.buffer at call time, exactly as the
        # attribute-bound callables they replace did.
        def write_text(s):
            return self.buffer.write(s.encode("ascii"))

        def read_text():
            return self.buffer.read().decode("ascii")

        def readline_text():
            return self.buffer.readline().decode("ascii")

        self.write = write_text
        self.read = read_text
        self.readline = readline_text
def print(*args, **kwargs):
    """Replacement for the built-in print that writes through ``file.write``.

    Recognised keyword arguments: ``sep`` (default " "), ``file`` (default
    ``sys.stdout``), ``end`` (default "\\n") and ``flush`` (default False).
    Each argument is converted with str() and written separately, with
    ``sep`` written between consecutive arguments.
    """
    sep = kwargs.pop("sep", " ")
    file = kwargs.pop("file", sys.stdout)
    for position, item in enumerate(args):
        # Separator goes before every item except the first.
        if position:
            file.write(sep)
        file.write(str(item))
    file.write(kwargs.pop("end", "\n"))
    if kwargs.pop("flush", False):
        file.flush()
# Swap the process-wide stdin/stdout for the buffered wrappers defined above:
# the byte-level FastIO on Python 2, the ASCII text IOWrapper on Python 3.
if sys.version_info[0] < 3:
    sys.stdin, sys.stdout = FastIO(sys.stdin), FastIO(sys.stdout)
else:
    sys.stdin, sys.stdout = IOWrapper(sys.stdin), IOWrapper(sys.stdout)
# Shadow the built-in input() so it reads one line through the replaced
# sys.stdin and strips the trailing "\r"/"\n" characters.
input = lambda: sys.stdin.readline().rstrip("\r\n")
def inp():
    """Return the next stdin line with trailing CR/LF characters removed."""
    line = sys.stdin.readline()
    return line.rstrip("\r\n")


def out(var):
    """Write str(var) to stdout without appending a newline."""
    text = str(var)
    sys.stdout.write(text)


def lis():
    """Return the next line parsed as a list of ints."""
    return [int(token) for token in inp().split()]


def stringlis():
    """Return the next line split on whitespace as a list of strings."""
    return [str(token) for token in inp().split()]


def sep():
    """Return a lazy map of ints over the tokens of the next line."""
    tokens = inp().split()
    return map(int, tokens)


def strsep():
    """Return a lazy map of strings over the tokens of the next line."""
    tokens = inp().split()
    return map(str, tokens)


def fsep():
    """Return a lazy map of floats over the tokens of the next line."""
    tokens = inp().split()
    return map(float, tokens)


def inpu():
    """Return the next line parsed as a single int."""
    return int(inp())

def build(arr,a,b,st,node):
    """Fill the min segment tree ``st`` for ``arr[a..b]`` rooted at ``node``.

    Children of ``node`` live at ``2*node+1`` (left half) and ``2*node+2``
    (right half). Returns the minimum stored at ``node``.
    """
    if a != b:
        middle = (a + b) // 2
        left_min = build(arr, a, middle, st, 2 * node + 1)
        right_min = build(arr, middle + 1, b, st, 2 * node + 2)
        st[node] = min(left_min, right_min)
    else:
        # Leaf: the segment is a single array element.
        st[node] = arr[a]
    return st[node]

def getmin(arr,a,b,l,r,st,node):
    """Return the minimum over query range ``[l, r]`` from segment tree ``st``.

    ``node`` is the tree index covering ``[a, b]``. A query that does not
    intersect the node's segment yields ``float("inf")``. ``arr`` is only
    passed through to the recursive calls.
    """
    no_overlap = l > b or r < a
    if no_overlap:
        return float("inf")
    fully_covered = l <= a and b <= r
    if fully_covered:
        return st[node]
    middle = (a + b) // 2
    left_min = getmin(arr, a, middle, l, r, st, 2 * node + 1)
    right_min = getmin(arr, middle + 1, b, l, r, st, 2 * node + 2)
    return min(left_min, right_min)

def dfs(cur,visited,d,res):
    """Mark every node reachable from ``cur`` and count the newly reached ones.

    Args:
        cur: Start node; it is marked visited but not counted here.
        visited: Mutable sequence of booleans indexed by node, updated in place.
        d: Adjacency mapping from a node to an iterable of its neighbours.
        res: One-element list; ``res[0]`` is incremented once for each node
            other than ``cur`` that this call marks as visited.

    The traversal uses an explicit stack instead of recursion, so its depth
    is not bounded by the interpreter recursion limit or the native stack
    (a single component with ~1e5 nodes overflowed the recursive version).
    """
    visited[cur] = True
    pending = [cur]
    while pending:
        node = pending.pop()
        for neighbour in d[node]:
            if not visited[neighbour]:
                # Mark on discovery so a node is never counted or pushed twice.
                visited[neighbour] = True
                res[0] += 1
                pending.append(neighbour)
def main():
    """Read ``n`` and ``n`` 1-based targets, then print the merged sum of squares.

    An undirected edge links each index ``i`` with ``arr[i] - 1``. The sizes
    of the connected components are collected, the two largest are merged
    into one, and the sum of the squared sizes is printed (0 when there are
    no components at all).
    """
    test_cases = 1  # the input holds a single test case
    for _ in range(test_cases):
        n = inpu()
        targets = lis()

        graph = defaultdict(list)
        for node in range(n):
            other = targets[node] - 1
            graph[node].append(other)
            graph[other].append(node)

        seen = [False] * (n + 1)
        sizes = []
        for start in range(n):
            if seen[start]:
                continue
            # dfs adds one to counter[0] per newly reached node; the start
            # node itself is accounted for by the initial 1.
            counter = [1]
            dfs(start, seen, graph, counter)
            sizes.append(counter[0])
        sizes.sort()

        if not sizes:
            print(0)
        elif len(sizes) == 1:
            print(sizes[0] * sizes[0])
        else:
            merged = sizes[-1] + sizes[-2]
            total = merged * merged + sum(size * size for size in sizes[:-2])
            print(total)
# Run the solver only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
?
Time: ? ms, memory: ? KB
Verdict: ?
Input
?
Participant's output
?
Jury's answer
?
Checker comment
?
Diagnostics
?
Click to see test details