Having trouble solving the problem Jewel-eating Monsters

Revision en1, by cercatrova, 2023-06-18 12:57:58

In the problem Jewel-eating Monsters, if the traveler has x coins in the evening and drops one coin into the pond at midnight, then in the morning they will have (x-1)*a coins. If they spend these coins on diamonds that cost c coins each, the remaining coins should be total_coins % c.
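As a quick sanity check of the single-night step, here is a minimal sketch assuming non-negative coin counts; `morning_coins` and `leftover` are hypothetical helper names, not part of the problem statement.

```cpp
#include <cassert>
#include <cstdint>

// Coins in the morning after one night: drop one coin at midnight,
// then the remaining x - 1 coins are multiplied by a.
std::int64_t morning_coins(std::int64_t x, std::int64_t a) {
    assert(x >= 1);  // the traveler needs at least one coin to drop
    return (x - 1) * a;
}

// Coins left over after buying as many diamonds (price c each) as possible.
std::int64_t leftover(std::int64_t coins, std::int64_t c) {
    assert(c > 0 && coins >= 0);
    return coins % c;
}
```

For x = 5, a = 3 and c = 7 the traveler has (5 - 1) * 3 = 12 coins in the morning and 12 % 7 = 5 coins left after buying one diamond.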

Applying the same logic for n nights:

Total_coins:

(x-1)*a + ((x-1)*a - 1)*a + (((x-1)*a - 1)*a - 1)*a + ... (up to n nights)
= (x*a - a) + (x*a^2 - a^2 - a) + ... + (x*a^n - a^n - a^(n-1) - ... - a)
= x*(a + a^2 + a^3 + ... + a^n) - (n*a + (n-1)*a^2 + (n-2)*a^3 + ... + a^n)

(a^j is subtracted once for every night from j to n, so its coefficient in the second bracket is n - j + 1.)
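The regrouping can be checked numerically for small inputs. This sketch compares the night-by-night sum with x*(a + ... + a^n) minus the weighted sum; `total_brute` and `total_regrouped` are hypothetical names, and both use exact 64-bit arithmetic, so they are only meant for small a and n.

```cpp
#include <cassert>
#include <cstdint>

// Brute force: simulate n nights and add up the morning amounts
// (the Total_coins sum above). No modulus, so small inputs only.
std::int64_t total_brute(std::int64_t x, std::int64_t a, int n) {
    assert(n >= 0);
    std::int64_t coins = x, total = 0;
    for (int k = 1; k <= n; ++k) {
        coins = (coins - 1) * a;  // one night: drop a coin, multiply by a
        total += coins;
    }
    return total;
}

// Regrouped form: x*(a + ... + a^n) - (n*a + (n-1)*a^2 + ... + 1*a^n).
std::int64_t total_regrouped(std::int64_t x, std::int64_t a, int n) {
    assert(n >= 0);
    std::int64_t geo = 0, weighted = 0, power = 1;
    for (int j = 1; j <= n; ++j) {
        power *= a;                        // power == a^j
        geo += power;                      // a + a^2 + ... + a^j
        weighted += (n - j + 1) * power;   // coefficient of a^j is n - j + 1
    }
    return x * geo - weighted;
}
```

For x = 5, a = 3, n = 3 both functions return 141 (12 + 33 + 96).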

The first bracket is a geometric series: a + a^2 + a^3 + ... + a^n = a*(a^n - 1)/(a-1), valid for a != 1 (for a = 1 it is simply n).
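The division by (a-1) is what forces a modular inverse later. A common alternative (a different technique from the closed form above) computes G(n) = a + a^2 + ... + a^n mod c by divide and conquer, using G(2k) = G(k)*(1 + a^k) and G(k+1) = a*(1 + G(k)). It needs no inverse, so c does not have to be prime and a = 1 is handled. This is a sketch assuming C++17 and c < 2^32, so that the product of two residues fits in an unsigned 64-bit integer; `geo_sum` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

using u64 = std::uint64_t;

// Returns {a + a^2 + ... + a^n, a^n}, both modulo m, without any division.
// Assumes 0 < m < 2^32 so products of two residues fit in 64 bits.
std::pair<u64, u64> geo_sum(u64 a, u64 n, u64 m) {
    assert(m > 0);
    if (n == 0) return {0, 1 % m};
    if (n % 2 == 1) {
        auto [g, p] = geo_sum(a, n - 1, m);  // G(n-1), a^(n-1)
        // G(n) = a * (1 + G(n-1)),  a^n = a * a^(n-1)
        return {a % m * ((1 + g) % m) % m, a % m * p % m};
    }
    auto [g, p] = geo_sum(a, n / 2, m);      // G(n/2), a^(n/2)
    // G(n) = G(n/2) * (1 + a^(n/2)),  a^n = (a^(n/2))^2
    return {g * ((1 + p) % m) % m, p * p % m};
}
```

The recursion depth is O(log n), so it is fine even for n around 10^18.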

For the second bracket, let
Sn = n*a + (n-1)*a^2 + (n-2)*a^3 + ... + a^n
Multiplying both sides by a:
a*Sn = n*a^2 + (n-1)*a^3 + ... + 2*a^n + a^(n+1)

Subtracting the two:
Sn - a*Sn = (n*a + (n-1)*a^2 + (n-2)*a^3 + ... + a^n) - (n*a^2 + (n-1)*a^3 + ... + 2*a^n + a^(n+1))
= n*a - a^2 - a^3 - ... - a^n - a^(n+1)
= (n+1)*a - a - a^2 - a^3 - ... - a^n - a^(n+1)
= (n+1)*a - (a + a^2 + a^3 + ... + a^n) - a^(n+1)
= (n+1)*a - a*(a^n - 1)/(a-1) - a^(n+1)

So Sn*(1-a) = (n+1)*a - a*(a^n - 1)/(a-1) - a^(n+1),
and therefore Sn = -((n+1)*a - a*(a^n - 1)/(a-1) - a^(n+1))/(a-1)
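To double-check the closed form for Sn, this sketch compares it with the term-by-term sum using exact integer arithmetic. It is only meant for small a >= 2 and small n; `ipow`, `sn_direct` and `sn_closed` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Integer power for small exact checks (no overflow handling).
std::int64_t ipow(std::int64_t a, int n) {
    std::int64_t r = 1;
    while (n-- > 0) r *= a;
    return r;
}

// Sn = n*a + (n-1)*a^2 + ... + 1*a^n, summed term by term.
std::int64_t sn_direct(std::int64_t a, int n) {
    std::int64_t s = 0;
    for (int j = 1; j <= n; ++j) s += (n - j + 1) * ipow(a, j);
    return s;
}

// Closed form derived above; requires a != 1 (division by a - 1).
std::int64_t sn_closed(std::int64_t a, int n) {
    assert(a != 1);
    std::int64_t geo = a * (ipow(a, n) - 1) / (a - 1);      // a + ... + a^n
    std::int64_t rhs = (n + 1) * a - geo - ipow(a, n + 1);  // equals Sn*(1-a)
    return -rhs / (a - 1);                                  // exact division
}
```

For a = 3, n = 3 both give 3*3 + 2*9 + 1*27 = 54.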

Hence, Total_coins = x*a*(a^n - 1)/(a-1) - Sn

and the coins remaining after buying diamonds = Total_coins % c, where c is the price of a single diamond.
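Because Total_coins is a difference, intermediate values reduced modulo c can be negative, and in C++ (since C++11) the result of `%` takes the sign of the dividend, so `-3 % 7 == -3`. A small helper that maps any value into [0, c) might look like this; `mod_floor` and `sub_mod` are hypothetical names, and the inputs to `sub_mod` are assumed to be residues already so the subtraction cannot overflow.

```cpp
#include <cassert>
#include <cstdint>

// Remainder in [0, c): C++ '%' keeps the sign of v, so fix negatives up.
std::int64_t mod_floor(std::int64_t v, std::int64_t c) {
    assert(c > 0);
    std::int64_t r = v % c;
    return r < 0 ? r + c : r;
}

// (x - y) mod c in [0, c), e.g. for (a^n - 1) mod c when a^n mod c == 0.
// Assumes x and y are already in [0, c).
std::int64_t sub_mod(std::int64_t x, std::int64_t y, std::int64_t c) {
    return mod_floor(x - y, c);
}
```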


I implemented the same logic in my code. It gives the correct result for the first test case (357), but the other test cases don't match. What mistakes am I making in the code or in the math?
My code:

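```cpp
#include<bits/stdc++.h>
#define fast ios::sync_with_stdio(false)
using namespace std;

typedef long long ll;

// Binary exponentiation: returns a^b % mod.
ll mod_ex(ll a,ll b, ll mod){
    ll res=1;
    while(b){
        if(b%2){
            res=(res*a)%mod;
        }
        a=(a*a)%mod;
        b/=2;
    }
    return res;
}
int main()
{
    fast;
    ll x,a,n,c;
    while(true){
        cin>>x>>a>>n>>c;
        if(x+a+n+c==0)break;

        // intended: p = a*(a^n - 1) % c
        ll p=((a%c)*(mod_ex(a,n,c)-1))%c;
        // according to Fermat's little theorem, (a/b)%c = (a * b^(c-2))%c when c is prime and b is not a multiple of c
        // intended: q = a*(a^n - 1)/(a-1) % c, the geometric sum
        ll q=(p*mod_ex(a-1,c-2,c))%c;

        // intended: xq = x*(a + a^2 + ... + a^n) % c
        ll xq=(x*q)%c;
        // intended: yq = Sn % c, using the closed form for Sn
        ll yq=-(((n+1)*a-q-mod_ex(a,n+1,c))%c*mod_ex(a-1,c-2,c))%c;
        // intended: Total_coins % c
        ll ans=(xq-yq)%c;
        cout<<ans<<endl;
    }
    return 0;
}
```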
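One way to sidestep the modular division entirely (a swapped-in technique, not the approach used above) is to advance the state (coins, running total, 1) with a 3x3 matrix and raise it to the n-th power by binary exponentiation. This works for any modulus c, including composite c and a = 1. The sketch assumes the Total_coins model derived in this post and c < 2^32; `Mat`, `mul` and `total_coins_mod` are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using u64 = std::uint64_t;
using Mat = std::array<std::array<u64, 3>, 3>;

// Multiply two 3x3 matrices modulo m (m < 2^32 keeps every sum below 2^64).
Mat mul(const Mat& A, const Mat& B, u64 m) {
    Mat C{};
    for (int i = 0; i < 3; ++i)
        for (int k = 0; k < 3; ++k)
            for (int j = 0; j < 3; ++j)
                C[i][j] = (C[i][j] + A[i][k] * B[k][j]) % m;
    return C;
}

// Total_coins mod m for the model in this post:
//   f_k = a*f_{k-1} - a (drop one coin, then multiply by a), f_0 = x
//   T_k = T_{k-1} + f_k, T_0 = 0
// The column vector (f, T, 1) is multiplied by M once per night.
u64 total_coins_mod(u64 x, u64 a, u64 n, u64 m) {
    assert(m > 0 && m < (1ULL << 32));
    u64 am = a % m, neg_a = (m - am) % m;  // a and -a modulo m
    Mat M = {{{am, 0, neg_a},
              {am, 1 % m, neg_a},
              {0, 0, 1 % m}}};
    Mat R = {{{1 % m, 0, 0},
              {0, 1 % m, 0},
              {0, 0, 1 % m}}};             // identity
    for (u64 e = n; e > 0; e >>= 1) {      // binary exponentiation: R = M^n
        if (e & 1) R = mul(R, M, m);
        M = mul(M, M, m);
    }
    // (f_n, T_n, 1) = R * (x, 0, 1); T_n is the second component.
    return (R[1][0] * (x % m) + R[1][2]) % m;
}
```

For example, total_coins_mod(5, 3, 3, 1000) gives 141, matching the brute-force sum 12 + 33 + 96. If the problem actually asks only for the coins after the last night, the first component, (R[0][0]*x + R[0][2]) % m, would be the value to print instead.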

I am quite a beginner in competitive programming. Any constructive criticism is appreciated.

Tags: binary exponentiation, math, modular exponentiation
