Having trouble solving the problem "Jewel-eating Monsters"

Revision en1, by cercatrova, 2023-06-18 12:57:58

In the problem Jewel-eating Monsters, if the traveler has x coins in the evening and drops one coin into the pond at midnight, then in the morning they will have (x-1)*a coins. If they then spend their total coins on diamonds that cost c coins each, the remaining coins should be total_coins % c.

Applying the same logic, if the traveler repeats this for n nights:

Total_coins: 

(x-1)*a + ((x-1)*a-1)*a + (((x-1)*a-1)*a-1)*a + ... (up to n nights)
= x*a - a + x*a^2 - a^2 - a + ... + x*a^n - a^n - a^(n-1) - ... - a
= x*(a + a^2 + a^3 + ... + a^n) - (n*a + (n-1)*a^2 + (n-2)*a^3 + ... + a^n)
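A quick way to sanity-check this expansion is to compare it with a direct simulation on tiny inputs. The sketch below assumes the model described above (each night: drop one coin, multiply the rest by a; Total_coins adds up the morning amounts). The function names are hypothetical, and the arithmetic is exact, so it is only meant for small x, a and n.

```cpp
#include <cassert>

// Direct simulation of the model above: each night one coin is dropped and
// the rest is multiplied by a; the morning amounts are summed.
long long totalBySimulation(long long x, long long a, int n) {
    assert(n >= 0);
    long long coins = x, total = 0;
    for (int night = 1; night <= n; ++night) {
        coins = (coins - 1) * a;  // morning amount after this night
        total += coins;
    }
    return total;
}

// Collected form: x*(a + ... + a^n) - (n*a + (n-1)*a^2 + ... + 1*a^n)
long long totalByCollectedForm(long long x, long long a, int n) {
    assert(n >= 0);
    long long power = 1, geo = 0, weighted = 0;
    for (int k = 1; k <= n; ++k) {
        power *= a;                       // a^k
        geo += power;                     // a + a^2 + ... + a^k
        weighted += (n - k + 1) * power;  // coefficient n-k+1 on a^k
    }
    return x * geo - weighted;
}
```

For example, totalBySimulation(2, 2, 2) and totalByCollectedForm(2, 2, 2) both return 4 (two mornings with 2 coins each).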

For the geometric series, (a + a^2 + a^3 + ... + a^n) = a*(a^n - 1)/(a-1), which is valid for a != 1.
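Because a*(a^n - 1)/(a-1) divides by (a-1), evaluating it modulo c needs an inverse of (a-1) modulo c, which does not always exist. One division-free alternative (a swapped-in technique, not the one used in the code of this post) is divide and conquer on n: G(2k) = G(k) + a^k*G(k) and G(2k+1) = G(2k) + a^(2k+1), where G(n) = a + ... + a^n. A minimal sketch, assuming c < 2^31 so that products fit in long long; powAndGeoSum is a hypothetical name.

```cpp
#include <cassert>
#include <utility>

// Returns {a^n mod m, (a + a^2 + ... + a^n) mod m} without any division,
// so m may be any modulus >= 1, prime or not. Assumes m < 2^31.
std::pair<long long, long long> powAndGeoSum(long long a, long long n, long long m) {
    assert(m >= 1 && n >= 0);
    if (n == 0) return {1 % m, 0};  // a^0 = 1, empty sum = 0
    a %= m;
    if (a < 0) a += m;
    if (n % 2 == 1) {
        // G(n) = G(n-1) + a^n
        std::pair<long long, long long> prev = powAndGeoSum(a, n - 1, m);
        long long pn = prev.first * a % m;  // a^n
        return {pn, (prev.second + pn) % m};
    }
    // G(n) = G(n/2) + a^(n/2) * G(n/2)
    std::pair<long long, long long> half = powAndGeoSum(a, n / 2, m);
    long long p = half.first, s = half.second;
    return {p * p % m, (s + p * s) % m};
}
```

For example, powAndGeoSum(5, 2, 4).second is 2, i.e. (5 + 25) % 4, even though a-1 = 4 has no inverse modulo 4.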

For (n*a + (n-1)*a^2 + (n-2)*a^3 + ... + a^n), let
Sn = n*a + (n-1)*a^2 + (n-2)*a^3 + ... + a^n
Multiplying both sides by a:
a*Sn = n*a^2 + (n-1)*a^3 + ... + 2*a^n + a^(n+1)

Now,
Sn - a*Sn = (n*a + (n-1)*a^2 + (n-2)*a^3 + ... + a^n) - (n*a^2 + (n-1)*a^3 + ... + 2*a^n + a^(n+1))
= n*a - a^2 - a^3 - ... - a^n - a^(n+1)   (each a^k with 2 <= k <= n has coefficient (n-k+1) - (n-k+2) = -1)
= (n+1)*a - a - a^2 - a^3 - ... - a^n - a^(n+1)
= (n+1)*a - (a + a^2 + a^3 + ... + a^n) - a^(n+1)
= (n+1)*a - a*(a^n - 1)/(a-1) - a^(n+1)

So, Sn*(1-a) = (n+1)*a - a*(a^n - 1)/(a-1) - a^(n+1)
Therefore, Sn = -((n+1)*a - a*(a^n - 1)/(a-1) - a^(n+1))/(a-1)   (for a != 1)
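The closed form for Sn can be checked against the direct weighted sum with exact integers for small a != 1. A minimal sketch; snDirect and snClosedForm are hypothetical names. The integer divisions are exact because a^n - 1 is always divisible by a - 1 and the numerator equals Sn*(1-a).

```cpp
#include <cassert>

// Direct sum Sn = n*a + (n-1)*a^2 + ... + 1*a^n (exact, small inputs only).
long long snDirect(long long a, int n) {
    long long power = 1, sum = 0;
    for (int k = 1; k <= n; ++k) {
        power *= a;                  // a^k
        sum += (n - k + 1) * power;  // weight n-k+1
    }
    return sum;
}

// Closed form from the derivation, valid only for a != 1:
// Sn = -((n+1)*a - a*(a^n - 1)/(a-1) - a^(n+1)) / (a-1)
long long snClosedForm(long long a, int n) {
    assert(a != 1 && n >= 0);
    long long an = 1;
    for (int k = 0; k < n; ++k) an *= a;     // a^n
    long long geo = a * (an - 1) / (a - 1);  // a + a^2 + ... + a^n
    return -((n + 1) * a - geo - an * a) / (a - 1);
}
```

For example, with a = 3 and n = 3 both functions give 3*3 + 2*9 + 1*27 = 54.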

Hence, Total_coins = x*a*(a^n - 1)/(a-1) - Sn

And the remaining coins after buying diamonds = Total_coins % c, where c is the price of a single diamond.
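Total_coins % c can also be evaluated without dividing by (a-1) by tracking the morning amount v and the running total S together as a linear recurrence: v' = a*v - a and S' = S + v'. This is a 3x3 matrix raised to the n-th power with binary exponentiation, and it works for any modulus c >= 1, prime or not. It is a swapped-in technique rather than the closed form above; it assumes the same model and c < 2^31, and remainingCoins and mulMod are hypothetical names.

```cpp
#include <array>
#include <cassert>

using Mat = std::array<std::array<long long, 3>, 3>;

// 3x3 matrix product modulo m (entries in [0, m), m < 2^31 assumed).
Mat mulMod(const Mat& A, const Mat& B, long long m) {
    Mat C{};
    for (int i = 0; i < 3; ++i)
        for (int k = 0; k < 3; ++k)
            for (int j = 0; j < 3; ++j)
                C[i][j] = (C[i][j] + A[i][k] * B[k][j]) % m;
    return C;
}

// Total_coins % c for any c >= 1. State (v, S, 1), starting at (x, 0, 1):
// v' = a*v - a (one night), S' = S + v' = a*v + S - a.
long long remainingCoins(long long x, long long a, long long n, long long c) {
    assert(c >= 1 && n >= 0);
    long long am = ((a % c) + c) % c;
    long long negA = (c - am) % c;  // -a mod c
    Mat step = {{{am, 0, negA}, {am, 1 % c, negA}, {0, 0, 1 % c}}};
    Mat result = {{{1 % c, 0, 0}, {0, 1 % c, 0}, {0, 0, 1 % c}}};
    for (; n > 0; n >>= 1) {  // binary exponentiation of the step matrix
        if (n & 1) result = mulMod(result, step, c);
        step = mulMod(step, step, c);
    }
    long long xm = ((x % c) + c) % c;
    // S_n = row 1 of step^n applied to (x, 0, 1)
    return (result[1][0] * xm + result[1][2]) % c;
}
```

For example, remainingCoins(2, 2, 2, 1000) returns 4, matching two nights of the simulation (2 + 2).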


I tried to implement the same logic in my code. It gives the correct result for the first test case (357), but the other test cases do not match. What mistakes am I making in the code or in the math?
My code:

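```cpp
#include<bits/stdc++.h>
#define fast ios::sync_with_stdio(false)
using namespace std;

typedef long long ll;

// Binary exponentiation: computes a^b modulo mod by repeated squaring.
ll mod_ex(ll a,ll b, ll mod){
    ll res=1;
    while(b){
        if(b%2){
            res=(res*a)%mod;
        }
        a=(a*a)%mod;
        b/=2;
    }
    return res;
}
int main()
{
    fast;
    ll x,a,n,c;
    while(true){
        cin>>x>>a>>n>>c;
        if(x+a+n+c==0)break;

        // p = a*(a^n - 1) mod c
        ll p=((a%c)*(mod_ex(a,n,c)-1))%c;
        // according to Fermat's little theorem, (a/b)%c = (a * b^(c-2))%c when c is prime and b is not a multiple of c
        // q = p/(a-1) mod c = (a + a^2 + ... + a^n) mod c
        ll q=(p*mod_ex(a-1,c-2,c))%c;

        // xq = x*(a + a^2 + ... + a^n) mod c
        ll xq=(x*q)%c;
        // yq = Sn mod c = -((n+1)*a - q - a^(n+1))/(a-1) mod c
        ll yq=-(((n+1)*a-q-mod_ex(a,n+1,c))%c*mod_ex(a-1,c-2,c))%c;
        // ans = Total_coins mod c = (xq - Sn) mod c
        ll ans=(xq-yq)%c;
        cout<<ans<<endl;
    }
    return 0;
}
```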
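On the Fermat comment in the code: b^(c-2) is the inverse of b modulo c only when c is prime and c does not divide b. If c is not guaranteed to be prime (an assumption worth checking against the problem statement), the extended Euclidean algorithm gives an inverse whenever gcd(b, c) = 1 and reports when none exists. A minimal sketch; modInverse is a hypothetical name.

```cpp
#include <cassert>

// Inverse of b modulo m via the extended Euclidean algorithm.
// Works for composite m as well; returns -1 when gcd(b, m) != 1.
long long modInverse(long long b, long long m) {
    assert(m >= 1);
    b = ((b % m) + m) % m;
    // Invariant: old_r == old_s * b (mod m) and r == s * b (mod m).
    long long old_r = b, r = m, old_s = 1, s = 0;
    while (r != 0) {
        long long q = old_r / r;
        long long t = old_r - q * r; old_r = r; r = t;
        t = old_s - q * s; old_s = s; s = t;
    }
    if (old_r != 1) return -1;  // gcd(b, m) != 1: no inverse exists
    return ((old_s % m) + m) % m;
}
```

For example, modInverse(3, 10) is 7 (3*7 = 21, and 21 % 10 = 1), while modInverse(4, 8) is -1 because gcd(4, 8) = 4.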

I am quite a beginner in competitive programming. Any constructive criticism is appreciated.

Tags: binary exponentiation, math, modular exponentiation
