### brownfox2k6's blog

By brownfox2k6, history, 2 months ago,

#### Problem source

104168D2 - Nested Sum (Hard Version)

#### Statement

Given an array of $n$ positive integers $a_{1}, a_{2},...,a_{n}$, find the value of $\sum_{i=1}^{n}\sum_{j=i+1}^{n}\sum_{k=j+1}^{n}a_{i}a_{j}a_{k}$ modulo $10^{9}+7$.

#### Input

The first line of input contains an integer $t$ ($1 \le t \le 10^{4}$).

The first line of each test case contains an integer $n$ ($1 \le n \le 10^{5}$), the size of the array.

The second line of each test case contains $n$ integers $a_{1}, a_{2},...,a_{n}$ ($1 \le a_{i} \le 10^{9}$), the elements of the array.

The sum of $n$ over all test cases doesn't exceed $5\cdot 10^{5}$.

#### Output

For each test case, output one line containing one integer: the sum described in the problem modulo $10^{9}+7$.


#### What I've done

I tried to solve this problem for an hour, but when I submit my solution, many test cases get WA:

My verdict

I don't know why my code fails. Can you find my mistake? This is my code:

My code


 » 2 months ago, # |   +2 sorry cnt help u bro
 » 2 months ago, # |   +4 The problem is on line 38 with pre1[n] - pre1[i+1]. This value might be negative. You want all numbers to lie between 0 and MOD-1. You can modify the add function to handle negative values, which gives AC. But I would recommend using a modint template because it is easier and cleaner. Code with modint template:

```cpp
#include <bits/stdc++.h>
using namespace std;

#define int long long
#define fastio ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0)
#define multitest int _T; cin >> _T; while (_T--)

const int MOD = 1000000007;

template <int M>
struct modint {
    static int _pow(int n, int k) {
        int r = 1;
        for (; k > 0; k >>= 1, n = (n*n)%M)
            if (k&1) r = (r*n)%M;
        return r;
    }

    int v;
    // Normalize any input (including negative values) into [0, M-1]
    modint(int n = 0) : v(n%M) { v += (M&(0-(v<0))); }

    friend string to_string(const modint n) { return to_string(n.v); }
    friend istream& operator>>(istream& i, modint& n) { return i >> n.v; }
    friend ostream& operator<<(ostream& o, const modint n) { return o << n.v; }
    template <class T> explicit operator T() { return T(v); }

    friend bool operator==(const modint n, const modint m) { return n.v == m.v; }
    friend bool operator!=(const modint n, const modint m) { return n.v != m.v; }
    friend bool operator<(const modint n, const modint m) { return n.v < m.v; }
    friend bool operator<=(const modint n, const modint m) { return n.v <= m.v; }
    friend bool operator>(const modint n, const modint m) { return n.v > m.v; }
    friend bool operator>=(const modint n, const modint m) { return n.v >= m.v; }

    modint& operator+=(const modint n) { v += n.v; v -= (M&(0-(v>=M))); return *this; }
    modint& operator-=(const modint n) { v -= n.v; v += (M&(0-(v<0))); return *this; }
    modint& operator*=(const modint n) { v = (v*n.v)%M; return *this; }
    modint& operator/=(const modint n) { v = (v*_pow(n.v, M-2))%M; return *this; }

    friend modint operator+(const modint n, const modint m) { return modint(n) += m; }
    friend modint operator-(const modint n, const modint m) { return modint(n) -= m; }
    friend modint operator*(const modint n, const modint m) { return modint(n) *= m; }
    friend modint operator/(const modint n, const modint m) { return modint(n) /= m; }

    modint& operator++() { return *this += 1; }
    modint& operator--() { return *this -= 1; }
    modint operator++(signed) { modint t = *this; return *this += 1, t; }
    modint operator--(signed) { modint t = *this; return *this -= 1, t; }
    modint operator+() { return *this; }
    modint operator-() { return modint(0) -= *this; }

    // O(log k) modular exponentiation
    modint pow(const int k) const {
        return k < 0 ? _pow(v, M-1-(-k%(M-1))) : _pow(v, k);
    }
    modint inv() const { return _pow(v, M-2); }
};
// in case of error remove ++ to += 1

using modi = modint<(int)1e9+7>;

void solve() {
    int n;
    cin >> n;
    vector<int> a(n);
    vector<modi> pre1(n+1);   // pre1[i] = a[0] + ... + a[i-1]
    pre1[0] = 0;
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
        pre1[i+1] = a[i]+pre1[i];
    }
    // Fewer than three elements: no triple exists
    if (n < 3) { cout << 0 << endl; return; }

    vector<modi> pre2(n-1);   // prefix sums of a[i] * (sum of elements after i)
    pre2[0] = 0;
    for (int i = 1; i < n-1; ++i)
        pre2[i] = pre2[i-1] + a[i]*(pre1[n] - pre1[i+1]);

    modi ans = 0;
    for (int i = 0; i < n-2; ++i)
        ans += a[i]*(pre2[n-2] - pre2[i]);
    cout << ans << endl;
}

signed main() {
    fastio;
    multitest solve();
}
```
•  » » 2 months ago, # ^ |   0 Is that why my code gets a runtime error when I port it to Python? Verdict: Test 1: OK, 0 point(s). Group G1: 0.0 point(s). Tests 2 to 31: WRONG_ANSWER, except RUNTIME_ERROR on tests 21, 23, 27, 30 and 31. Group G2: 0.0 point(s). Tests 32 to 101: WRONG_ANSWER on tests 32 to 71, 74, 76, 77 and 79 to 81; RUNTIME_ERROR on tests 72, 73, 75, 78 and 82 to 101. Points: 0.0
•  » » 2 months ago, # ^ | ← Rev. 2 →   0 But pre1 is my prefix sum, so why can pre1[n] - pre1[i+1] be negative? (pre1[n] is the last element of the pre1 array.)
•  » » » 2 months ago, # ^ |   0 Because you are taking it modulo a number. Let's say mod was 5 and the array was [3,4]. Then pre1 would be [0, 3, (3+4)%5 = 2]. Then pre1[2] - pre1[1] will be -1.
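To make the example above concrete, here is a minimal sketch of the same situation. The helper names `prefix_mod` and `sub_mod` are hypothetical (they do not come from the original code); the point is only that a difference of two reduced prefix sums has to be brought back into the range from 0 to mod-1.

```cpp
#include <cassert>
#include <vector>

// Prefix sums stored modulo `mod`:
// pre[0] = 0, pre[i] = (a[0] + ... + a[i-1]) % mod.
// Assumes non-negative inputs, as in the problem statement.
std::vector<long long> prefix_mod(const std::vector<long long>& a, long long mod) {
    assert(mod > 0);
    std::vector<long long> pre(a.size() + 1, 0);
    for (std::size_t i = 0; i < a.size(); ++i)
        pre[i + 1] = (pre[i] + a[i] % mod) % mod;
    return pre;
}

// Difference of two residues, normalized back into [0, mod).
long long sub_mod(long long x, long long y, long long mod) {
    assert(0 <= x && x < mod && 0 <= y && y < mod);
    return ((x - y) % mod + mod) % mod;
}
```

With mod 5 and the array [3, 4], `prefix_mod` gives [0, 3, 2]; the raw difference `pre[2] - pre[1]` is -1, while `sub_mod(pre[2], pre[1], 5)` gives 4, which is the true range sum 4 reduced modulo 5.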
•  » » » » 2 months ago, # ^ |   +3 woah, thank you so much ❤️, I understand it now.
•  » » 2 months ago, # ^ |   0 As you can see from my code in the post, how do I implement mul() and add() correctly?
•  » » » 2 months ago, # ^ |   0 To implement them correctly, just use the long long (64-bit) type. Note that in the mul function you first multiply two ints and only then take the result modulo $10^9+7$. The product can be as large as $10^9 \cdot 10^9 = 10^{18}$, which doesn't fit in the 32-bit int type, so the result of mul may be wrong due to overflow. With a 64-bit type you get a correct result, since any product of two reduced values is $\le (10^9+7-1)^2 < 2^{63}-1$ (the maximum value of long long).
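As a rough sketch of what such helpers could look like: the names `add` and `mul` come from the question, but the original code is not shown in the post, so the signatures and bodies below are assumptions. Both functions reduce their arguments first, so every intermediate value stays well inside the 64-bit range, and negative inputs are normalized too.

```cpp
#include <cassert>
#include <cstdint>

constexpr std::int64_t MOD = 1000000007;

// Bring any 64-bit value into [0, MOD).
std::int64_t norm(std::int64_t x) {
    return (x % MOD + MOD) % MOD;
}

// (a + b) mod MOD; arguments may be negative or unreduced.
std::int64_t add(std::int64_t a, std::int64_t b) {
    std::int64_t r = norm(a) + norm(b);   // at most 2*MOD - 2
    assert(r >= 0);
    return r >= MOD ? r - MOD : r;
}

// (a * b) mod MOD; the product of two reduced values is below 2^63.
std::int64_t mul(std::int64_t a, std::int64_t b) {
    return norm(a) * norm(b) % MOD;
}
```

For instance, `mul(1000000000, 1000000000)` returns 49 here, because $10^9 \equiv -7 \pmod{10^9+7}$, while the same multiplication in 32-bit int would overflow.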
 » 2 months ago, # | ← Rev. 2 →   +19 Interesting fact: Let $A = \sum_{i=1}^n x_i$, $B = \sum_{i=1}^n x_i^2$, $C = \sum_{i=1}^n x_i^3$. Then the answer is just $\frac{A^3 - 3(AB - C) - C}{6}$.
•  » » 2 months ago, # ^ |   +1 wow... I spent 2 hours thinking about how to implement the prefix sums, and then you gave me a simple equation. I realized I have a lot to improve. Thank you so much ❤️
•  » » 2 months ago, # ^ |   +1 Can you explain how you got this result? Thanks.
•  » » » 2 months ago, # ^ |   0 Well, in general, you only need to open the brackets:)
•  » » 2 months ago, # ^ |   0 How did you derive the formula?
•  » » » 2 months ago, # ^ | ← Rev. 4 →   +1 Okay. First, there is a wonderful theorem about symmetric polynomials (https://brilliant.org/wiki/symmetric-polynomials-definition/). It is not quite what you need, but it gives an understanding of how the problem should be solved. (You can search for better articles; this is the one I found in English.)

Now about how to come up with this solution. Let $A = \sum_{i=1}^n x_i$, $B = \sum_{i=1}^n x_i^2$, $C = \sum_{i=1}^n x_i^3$, and let $S$ be the sum we need to calculate. Look at the expression $A^3 = (\sum_{i=1}^n x_i)^3$. Obviously, it contains all the summands of $S$, but there are also "extra" summands. Specifically, by opening the parentheses we get

$A^3 = 6S + \sum_{i=1}^n x_i^3 + 3\sum_{i \ne j} x_i^2 x_j$

because every product of three distinct elements appears $3! = 6$ times and every product $x_i^2 x_j$ with $i \ne j$ appears 3 times.

Let's now calculate the sum $T = 3\sum_{i \ne j} x_i^2 x_j$. By a similar trick we get

$T = 3(\sum_{i=1}^n x_i^2)(\sum_{i=1}^n x_i) - 3\sum_{i=1}^n x_i^3 = 3(AB - C)$

(the product $AB$ also contains the "extra" terms $x_i^3$, so we subtract them). Substituting into the first equality, we get

$A^3 = 6S + C + 3(AB - C)$

so

$S = \frac{A^3 - C - 3(AB - C)}{6}$

So we only need to calculate $A$, $B$ and $C$. Of course, here we will not divide by 6, but multiply by $6^{-1} \bmod (10^9+7)$.
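The closed form above could be implemented roughly as follows. This is only a sketch: the function names `power` and `triple_sum_closed_form` are made up for illustration, the inverse of 6 is computed via Fermat's little theorem as $6^{MOD-2}$ (valid because $10^9+7$ is prime), and the input is assumed to contain positive integers as in the statement.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr std::int64_t MOD = 1000000007;

// Fast modular exponentiation: base^exp mod MOD, for exp >= 0.
std::int64_t power(std::int64_t base, std::int64_t exp) {
    assert(exp >= 0);
    std::int64_t result = 1;
    base %= MOD;
    for (; exp > 0; exp >>= 1) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
    }
    return result;
}

// Sum of a[i]*a[j]*a[k] over i < j < k, modulo MOD,
// using S = (A^3 - C - 3(AB - C)) / 6 = (A^3 - 3AB + 2C) / 6.
std::int64_t triple_sum_closed_form(const std::vector<std::int64_t>& a) {
    std::int64_t A = 0, B = 0, C = 0;
    for (std::int64_t x : a) {
        x %= MOD;
        A = (A + x) % MOD;                 // sum of x
        B = (B + x * x) % MOD;             // sum of x^2
        C = (C + x * x % MOD * x) % MOD;   // sum of x^3
    }
    std::int64_t A3 = A * A % MOD * A % MOD;
    std::int64_t AB = A * B % MOD;
    // The numerator can be negative before normalization.
    std::int64_t num = ((A3 + 2 * C - 3 * AB) % MOD + MOD) % MOD;
    // Multiply by the modular inverse of 6 instead of dividing.
    return num * power(6, MOD - 2) % MOD;
}
```

For the array [1, 2, 3] we have $A = 6$, $B = 14$, $C = 36$, so the numerator is $216 - 252 + 72 = 36$ and the result is 6, which matches the only triple $1 \cdot 2 \cdot 3$.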
 » 2 months ago, # | ← Rev. 2 →   0 Here is an alternative approach to this problem using dynamic programming. Let $dp(i, j)$ be the sum of the products over all combinations of $j$ elements chosen from the first $i$ elements of the array, with $dp(i, 0) = 1$ and $dp(0, j) = 0$ for $j > 0$. Then $\displaystyle dp(i, j) = dp(i - 1, j) + dp(i - 1, j - 1) \cdot a_i$. The answer to this problem is $dp(n, 3)$. The time complexity is $O(n \cdot k)$, where $k$ is the maximum length of a combination. Here $k$ is 3, so the time complexity is simply $O(3n) = O(n)$.
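One possible way to code this recurrence is sketched below. The function name `triple_sum_dp` is hypothetical, and the rolling one-dimensional array is an implementation choice rather than something stated in the comment: iterating $j$ downwards lets `dp[j - 1]` still hold the value from the previous $i$, so only $k + 1$ numbers are stored.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr std::int64_t MOD = 1000000007;

// dp[j] = sum of products over all combinations of j elements
// among the elements processed so far, modulo MOD.
// Assumes non-negative inputs, as in the problem statement.
std::int64_t triple_sum_dp(const std::vector<std::int64_t>& a, int k = 3) {
    assert(k >= 0);
    std::vector<std::int64_t> dp(k + 1, 0);
    dp[0] = 1;  // empty product
    for (std::int64_t x : a) {
        x %= MOD;
        // Go from high j to low j so dp[j - 1] is still the old value.
        for (int j = k; j >= 1; --j)
            dp[j] = (dp[j] + dp[j - 1] * x) % MOD;
    }
    return dp[k];
}
```

Calling `triple_sum_dp({1, 2, 3, 4})` gives 50, that is $6 + 8 + 12 + 24$; passing `k = 2` instead would give the sum of pairwise products.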
 » 3 weeks ago, # | ← Rev. 3 →   0 This code has easy-to-understand logic and gets accepted :3 NOTE: add and mul are only used to take the result % mod.

```cpp
int ans = 0, suf_sum = 0, sufP = 0;
for (int i = n - 1; i >= 0; --i) {
    // ans += a[i] * (sum of products of pairs to the right of i)
    ans = add(ans, mul(a[i], sufP));
    // suffix pairs += a[i] * (sum of elements to the right of i)
    sufP = add(sufP, mul(a[i], suf_sum));
    // suffix sum = sum of the elements already passed
    suf_sum = add(suf_sum, a[i]);
}
```