// E - Distance on Large Perfect Binary Tree
#include <bits/stdc++.h>
using namespace std;
// Shorthand aliases (plain textual substitutions).
#define ll long long
#define ld long double
#define ar array
#define fi first
#define se second
#define ii pair <int, int>
#define vt vector
#define pb push_back
// Container size cast to int.
#define sz(x) (int)(x).size()
// Sets every byte of `a` to byte value `x` via memset; sizeof(a) is only the
// full extent when `a` is an actual array (not a pointer).
#define reset(a, x) memset(a, x, sizeof(a))
// Writes a newline to stderr; note the macro already includes the semicolon.
#define END cerr << "\n";
// Replaces every later `endl` token with "\n", so line ends do not flush.
// (A qualified `std::endl` would therefore no longer compile.)
#define endl "\n"
// Debug helper: error(a, b) prints "a=<value>, b=<value>, " to stderr.
// The argument list is stringized, commas become spaces, and the resulting
// whitespace-separated names are walked by err() alongside the values.
// Uses the GNU named variadic-macro form `args...`. Arguments whose text
// contains commas or spaces (e.g. `a + b`, `f(x, y)`) get mislabeled because
// names are split on whitespace.
#define error(args...) { \
string _s = #args; \
replace(_s.begin(), _s.end(), ',', ' '); \
stringstream _ss(_s); \
istream_iterator <string> _it(_ss); \
err(_it, args); \
}
// Recursion terminator for the error() debug macro: no values remain.
void err(istream_iterator<string>) {}

// Prints "<name>=<value>, " to stderr for the first value, taking the label
// from the name stream, then handles the remaining values the same way.
template<typename Value, typename... Rest>
void err(istream_iterator<string> names, Value value, Rest... rest) {
    const string label = *names;
    cerr << label << "=" << value << ", ";
    ++names;
    err(names, rest...);
}
// Extracts a single whitespace-delimited value from standard input.
template<class Value> void read(Value& target) {
    cin >> target;
}

// Extracts each argument from standard input, strictly left to right.
template<class First, class... Rest> void read(First& first, Rest&... rest) {
    cin >> first;
    read(rest...);
}
// Sends one value to standard output, with no separator.
template<class Value> void write(Value value) {
    cout << value;
}

// Sends every argument to standard output back-to-back, with no separators.
template<class First, class... Rest> void write(const First& first, const Rest&... rest) {
    write(first);
    write(rest...);
}

// Ends the current output line.
void print() {
    write("\n");
}

// Sends the arguments to standard output separated by single spaces and
// terminated by a newline.
template<class First, class... Rest> void print(const First& first, const Rest&... rest) {
    write(first);
    const bool moreToCome = sizeof...(rest) != 0;
    if (moreToCome) {
        write(' ');
    }
    print(rest...);
}
#define file ""
// Switches the console streams to fast mode: the C++ streams stop
// synchronizing with C stdio, and cin is untied from cout so reads no longer
// flush pending output. The freopen line is kept commented out; it would
// redirect I/O to `<file>.in` / `<file>.out` using the prefix above.
inline void setIO() {
    //freopen(file".out", "w", stdout); freopen(file".in", "r", stdin);
    // Call the static sync_with_stdio directly instead of through the pointer
    // returned by cin.tie(): that pointer is the previously tied stream and is
    // null once cin is untied, which made a repeated call undefined behavior.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
}
// Large sentinels and array-sizing constants.
const long long INF = 2000000000000000000LL;  // 2e18
const int inf = 0x3c3c3c3c;
const int N = 1000000;                         // 1e6
const int D = 2000001;                         // 2e6 + 1: indices 0..2e6
const int mod = 998244353;

// Global state; zero-initialized because it has static storage duration.
int n, d;
int pow2[D];
int dp[D];

// a * b reduced modulo `mod`; the product is formed in 64 bits so it cannot
// overflow for inputs in [0, mod).
int mul(int a, int b) {
    const long long product = static_cast<long long>(a) * b;
    return static_cast<int>(product % mod);
}

// a + b reduced modulo `mod` with a single conditional subtraction, so the
// result lies in [0, mod) only when both inputs already do.
int add(int a, int b) {
    const int sum = a + b;
    return sum >= mod ? sum - mod : sum;
}
// Returns 2^exponent modulo `modulus` (exponent >= 0) by binary exponentiation.
static long long powerOfTwoMod(long long exponent, long long modulus) {
    long long result = 1 % modulus;
    long long base = 2 % modulus;
    while (exponent > 0) {
        if (exponent & 1) result = result * base % modulus;
        base = base * base % modulus;
        exponent >>= 1;
    }
    return result;
}

// Counts ordered vertex pairs (u, v) at distance exactly `dist` in a perfect
// binary tree with `levels` levels (2^levels - 1 vertices), modulo 998244353.
//
// Every path has a unique highest vertex. Group paths by that vertex's depth
// `depth` (2^depth such vertices, `below = levels - 1 - depth` levels under it):
//   * straight paths descend on one side only: 2^dist endpoints when
//     dist <= below, counted twice for the two orientations of the pair;
//   * bent paths descend `a` levels into the left child's subtree and
//     `dist - a` into the right one, for every a with 1 <= a <= dist - 1 and
//     both lengths <= below. Each valid `a` gives 2^(a-1) * 2^(dist-a-1)
//     unordered = 2^(dist-1) ordered pairs. Both a and its mirror dist - a
//     must be counted (the previous dp only counted one of each mirrored pair).
// Runs in O(levels + log dist) time and O(1) extra memory, so it does not
// depend on any precomputed power table covering every depth.
int countPairsAtDistance(int levels, int dist) {
    constexpr long long kMod = 998244353;  // same modulus as the global `mod`
    if (levels <= 0 || dist < 0) return 0;
    if (dist == 0) {
        // Only the pairs (v, v): one per vertex.
        return static_cast<int>((powerOfTwoMod(levels, kMod) - 1 + kMod) % kMod);
    }
    const long long straightPerTop = 2 * powerOfTwoMod(dist, kMod) % kMod;
    const long long bentPerSplit = powerOfTwoMod(dist - 1, kMod);
    long long total = 0;
    long long topsAtDepth = 1;  // 2^depth mod kMod
    for (int depth = 0; depth < levels; ++depth) {
        const long long below = levels - 1 - depth;
        // Straight paths must be added regardless of whether any bent split
        // exists (e.g. dist == 1 has no bent paths at all).
        long long perTop = dist <= below ? straightPerTop : 0;
        const long long minSplit = std::max<long long>(1, dist - below);
        const long long maxSplit = std::min<long long>(dist - 1, below);
        if (minSplit <= maxSplit) {
            perTop = (perTop + (maxSplit - minSplit + 1) * bentPerSplit) % kMod;
        }
        total = (total + perTop * topsAtDepth) % kMod;
        topsAtDepth = topsAtDepth * 2 % kMod;
    }
    return static_cast<int>(total);
}

// Reads N (number of tree levels) and D (target distance) from standard input
// and writes the number of ordered vertex pairs at distance D, mod 998244353.
void solve() {
    int levels = 0;
    int dist = 0;
    std::cin >> levels >> dist;
    std::cout << countPairsAtDistance(levels, dist);
}
// Entry point: set up fast console I/O, then run solve() once per test case
// (a single case unless the case-count read below is re-enabled).
int32_t main() {
    setIO();
    int cases = 1;
    // read(cases);  // enable for inputs that start with a case count
    for (int tc = 1; tc <= cases; ++tc) {
        // write("Case ", tc, ": ");
        solve();
    }
    return 0;
}