|

General

# Author Problem Lang Verdict Time Memory Sent Judged
127234490 Practice:
Qiuly.qwq
1276F - 5 GNU C++14 Accepted 608 ms 135360 KB 2021-08-28 07:08:20 2021-08-28 07:08:20
→ Source
#include <bits/stdc++.h>
using namespace std;

// Short type aliases used throughout the solution.
using ll = long long;
using pii = pair <int, int>;

// Abbreviations for pair members and common container calls.
#define fi first
#define se second
#define rez resize
#define pb push_back
#define mkp make_pair

// Loop helpers. Upper-case forms use a strict bound, lower-case forms are inclusive:
//   Lep(v, a, b): v = a .. b-1 ascending    Rep(v, a, b): v = a .. b+1 descending
//   lep(v, a, b): v = a .. b   ascending    rep(v, a, b): v = a .. b   descending
#define Lep(v, lo, hi) for (int v = lo; v < hi; ++ v)
#define Rep(v, hi, lo) for (int v = hi; v > lo; -- v)
#define lep(v, lo, hi) for (int v = lo; v <= hi; ++ v)
#define rep(v, hi, lo) for (int v = hi; v >= lo; -- v)

/// Prime modulus used by every arithmetic helper below.
const int mod = 998244353;

/// Returns x * y mod `mod` (the 64-bit intermediate avoids int overflow).
inline int mul (int x, int y) { return 1ll * x * y % mod; }
/// In place: x = (x - y) mod `mod`. Assumes 0 <= x, y < mod (single fix-up).
inline void sub (int &x, int y) { x -= y; if (x < 0) x += mod; }
/// In place: x = (x + y) mod `mod`. Assumes 0 <= x, y < mod (single fix-up).
inline void pls (int &x, int y) { x += y; if (x >= mod) x -= mod; }
/// Returns (x - y) mod `mod`. Assumes 0 <= x, y < mod.
inline int dec (int x, int y) { x -= y; if (x < 0) x += mod; return x; }
/// Returns (x + y) mod `mod`. Assumes 0 <= x, y < mod.
inline int add (int x, int y) { x += y; if (x >= mod) x -= mod; return x; }

/// Returns res * x^y mod `mod`, assuming 0 <= x < mod.
/// Because mod is prime, the exponent is reduced modulo (mod - 1) (Fermat).
/// This makes negative exponents work: y = -1 yields the modular inverse of x.
/// x^0 is 1. For x == 0 and y != 0 the result is 0 (0 has no inverse, so 0
/// is also returned for negative y).
inline int modpow (int x, long long y, int res = 1) {
	// Fermat reduction is only valid for x != 0: 0^(k*(mod-1)) is 0, not 1.
	if (x == 0) return y == 0 ? res : 0;
	// Bring y into [0, mod - 1). C++ '%' keeps the dividend's sign, so fix
	// negatives explicitly. The former one-step (y + mod - 1) % (mod - 1)
	// stayed negative for y < -(mod - 1), and `y >>= 1` on a negative value
	// never reaches 0 (infinite loop).
	y %= mod - 1;
	if (y < 0) y += mod - 1;
	for (; y; y >>= 1, x = mul (x, x)) if (y & 1) res = mul (x, res);
	return res;
}

char _c; bool _f;
/// Reads the next integer from stdin into x (a fast alternative to scanf).
/// Characters before the first digit are skipped; any '-' among them makes
/// the result negative. At end of input the function returns, leaving x as 0
/// (or the partial value read so far), instead of looping forever.
/// Afterwards _c holds the character that ended the number (EOF truncated to
/// char) and _f records whether a '-' was seen.
template <class T> inline void IN (T & x) {
	x = 0, _f = 0;
	// Keep getchar()'s result in an int: isdigit() is only defined for values
	// representable as unsigned char or EOF, and EOF must stay distinguishable.
	int c;
	while ((c = getchar ()) != EOF && ! isdigit (c)) if (c == '-') _f = 1;
	while (c != EOF && isdigit (c)) x = x * 10 + (c - '0'), c = getchar ();
	_c = static_cast <char> (c);
	if (_f) x = -x;
}

/// Lowers x to y when x > y; otherwise x is left untouched.
template <class T> inline void chkmin (T & x, T y) {
	if (! (x > y)) return;
	x = y;
}
/// Raises x to y when x < y; otherwise x is left untouched.
template <class T> inline void chkmax (T & x, T y) {
	if (! (x < y)) return;
	x = y;
}

// Capacity of every per-state / per-position array. The automata below use
// one slot per state, so this presumably assumes strings of roughly N / 2
// characters at most — confirm against the problem limits.
const int N = 2e5 + 5;
// Number of levels requested for the sparse table built in suffix_automaton::init.
const int LogN = 20;

// Input length.
int n;
// DFS clock used to assign dfn values.
int tim;
// Arrays indexed by automaton state, filled by suffix_automaton::dfs:
// val = len[v] - len[parent], sum = val accumulated from the root,
// dfn = preorder index, dep = depth in the suffix-link tree.
int val[N];
int sum[N];
int dfn[N];
// seg[k] = state whose preorder index is k (inverse of dfn).
int seg[N];
int dep[N];
// Not referenced in the visible code.
int tag[N];
// Input string, stored starting at str[1].
char str[N];

// {{{ rmq;

int dlen, fir[N], seq[N << 1], Log[N << 1], st[N << 1][19];

inline int chkgod (int x, int y) { return (dep[x] < dep[y]) ? x : y; }
inline int lca (int x, int y) {
x = fir[x], y = fir[y];
if (x > y) swap (x, y);
int res = Log[y - x + 1];
return chkgod (st[x][res], st[y - (1 << res) + 1][res]);
}

//}}}

struct suffix_automaton {
int lst = 1, tot = 1, fa[N], nod[N], pos[N], siz[N], len[N], ch[N][26];
vector <int> to[N];

inline void insert (int c, int p) {
int now = ++ tot, pre = lst;
len[lst = now] = len[pre] + 1, pos[now] = p, nod[p] = now;
for (; pre && ! ch[pre][c]; ch[pre][c] = now, pre = fa[pre]);
if (! pre) return fa[now] = 1, void ();
int son = ch[pre][c];
if (len[son] == len[pre] + 1) return fa[now] = son, void ();
int cop = ++ tot;
memcpy (ch[cop], ch[son], sizeof ch[cop]);
len[cop] = len[pre] + 1, fa[cop] = fa[son], fa[son] = fa[now] = cop;
for (; pre && ch[pre][c] == son; ch[pre][c] = cop, pre = fa[pre]);
}

void build () { lep (i, 1, tot) to[fa[i]].pb (i); }
void dfs (int u) {
dfn[u] = ++ tim, seg[tim] = u, seq[++ dlen] = u, fir[u] = dlen;
for (int v : to[u]) {
dep[v] = dep[u] + 1, val[v] = len[v] - len[u], sum[v] = sum[u] + val[v];
dfs (v), seq[++ dlen] = u;
}
}
inline void init () {
build (), dfs (1);

Log[0] = -1;
lep (i, 1, dlen) Log[i] = Log[i >> 1] + 1;
lep (i, 1, dlen) st[i][0] = seq[i];
Lep (j, 1, LogN) lep (i, 1, dlen - (1 << j) + 1)
st[i][j] = chkgod (st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
}
} s, t;

// {{{ solve

// Final answer accumulator (main prints it plus 2).
ll ans;
// res[id]: running path-length total kept alongside set S[id] by solve()/merge().
ll res[N];
// rt[u]: id of the set currently holding state u's data. Ids are swapped
// during small-to-large merging, so rt[u] need not equal u.
int rt[N];
// Largest / smallest s.pos seen in u's subtree (0 / n + 1 when none).
int mxpos[N];
int mipos[N];
// S[id]: dfn values (in automaton t) collected for set id, kept sorted.
set <int> S[N];

inline void merge (int u, int v) {
chkmax (mxpos[u], mxpos[v]), chkmin (mipos[u], mipos[v]);

if (S[rt[u]].size () < S[rt[v]].size ()) swap (rt[u], rt[v]);
for (int t : S[rt[v]]) {
if (S[rt[u]].find (t) != S[rt[u]].end ()) continue ;

int now = seg[t], top = 1, tmp;
set <int> :: iterator it = S[rt[u]].upper_bound (t);

if (it != S[rt[u]].end ()) tmp = lca (seg[*it], now), top = dep[top] > dep[tmp] ? top : tmp;
if (it != S[rt[u]].begin ()) -- it, tmp = lca (seg[*it], now), top = dep[top] > dep[tmp] ? top : tmp;
res[rt[u]] += sum[now] - sum[top], S[rt[u]].insert (t);
}
}
// Post-order DFS over automaton s's suffix-link tree, starting at state u.
// Each state first gets its own set (id u). If s created u for position p
// (s.pos[u] != 0), the set holds dfn[t.nod[p + 2]] when p + 2 <= n, with
// res = sum[t.nod[p + 2]], and mxpos = mipos = p. Children are then solved
// and merged into u. For every non-root u, a contribution scaled by
// dan = len[u] - len[fa[u]] is added to the global ans.
void solve (int u) {
rt[u] = u, mxpos[u] = 0, mipos[u] = n + 1;
if (s.pos[u]) {
// t was built from str read right to left, so t.nod[p + 2] is t's state for
// position p + 2. NOTE(review): p + 2 appears to skip the character at p + 1
// (the one a '*' would replace) — confirm against the problem statement.
if (s.pos[u] + 2 <= n)
S[rt[u]].insert (dfn[t.nod[s.pos[u] + 2]]), res[rt[u]] = sum[t.nod[s.pos[u] + 2]];
mxpos[u] = mipos[u] = s.pos[u];
}
// Recurse into children first; each child's data is folded into u.
for (int v : s.to[u]) solve (v), merge (u, v);

if (u != 1) {
// Number of strings (lengths len[fa[u]] + 1 .. len[u]) represented by state u.
int dan = s.len[u] - s.len[s.fa[u]];
// printf ("calc in %d [%d, %d] : %d, %lld, %d, %d\n", u, s.len[s.fa[u]] + 1, s.len[u],
//	dan, 1ll * res[rt[u]] * dan, (mipos[u] < n) * dan, dan - (mxpos[u] == s.len[u]));
// NOTE(review): per-term meaning is inferred, not stated in the source:
//   dan                       - the strings of u themselves,
//   res * dan                 - each string followed by '*' and a nonempty tail,
//   (mipos < n) * dan         - each string followed by '*' (some occurrence ends before n),
//   dan - (mxpos == len[u])   - '*' followed by each string (the longest string
//                               fails only when its last occurrence is the prefix).
ans += dan + 1ll * res[rt[u]] * dan + (mipos[u] < n) * dan + dan - (mxpos[u] == s.len[u]);
}
}

// }}}

int main () {
	// The input string is stored 1-indexed in str[1..n].
	scanf ("%s", str + 1);
	n = strlen (str + 1);

	// s: automaton of str read left to right, each state tagged with its end position.
	for (int i = 1; i <= n; ++ i) s.insert (str[i] - 'a', i);
	// t: automaton of str read right to left. Its tree provides the dfn/LCA structure.
	for (int i = n; i >= 1; -- i) t.insert (str[i] - 'a', i);

	s.build ();
	t.init ();
	solve (1);
	printf ("%lld\n", ans + 2);
	return 0;
}
?
Time: ? ms, memory: ? KB
Verdict: ?
Input
?
Participant's output
?
Jury's answer
?
Checker comment
?
Diagnostics
?
Click to see test details