Memoization with returning tuples does not work in Python

Revision en1, by beginner1010, 2019-09-11 20:58:39

Hello,

I tried the problem Segment Sum. The DP solution was pretty straightforward to me, so I decided to implement it in Python (the code is below in this post).

I used memoization, and the recursive function returns two numbers, i.e. each state of my DP table stores two values: the current sum and the number of ways that satisfy the constraints. After hours of debugging, I couldn't find any problem with my code. I decided to remove if dp[smaller][start][pos][mask][0] != -1: return dp[smaller][start][pos][mask] to see whether there was an issue with the DP table. Surprisingly, the code output correct results once I removed these two lines. It seems there is something wrong with returning tuples (or a list of size 2 here) from a recursive function in Python.
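For comparison, here is a minimal sketch of the same pattern (a memoized recursion that returns a (sum, ways) pair) on a toy problem. It is not the solution to Segment Sum: the function name sum_and_count and the toy task (all numbers with a fixed number of digits, leading zeros allowed) are assumptions made only for illustration, and the cache is functools.lru_cache rather than a hand-built table.

```python
from functools import lru_cache

MOD = 998244353

def sum_and_count(length):
    """Return (sum, count) over all numbers with `length` digits, leading zeros allowed."""

    @lru_cache(maxsize=None)
    def rec(pos):
        # No positions left: one way to finish, contributing nothing to the sum.
        if pos == -1:
            return (0, 1)
        res_sum = res_ways = 0
        for digit in range(10):
            cur_sum, cur_ways = rec(pos - 1)
            # Placing `digit` at `pos` adds digit * 10^pos to each of the cur_ways completions.
            res_sum = (res_sum + digit * pow(10, pos, MOD) * cur_ways + cur_sum) % MOD
            res_ways = (res_ways + cur_ways) % MOD
        return (res_sum, res_ways)

    return rec(length - 1)

print(sum_and_count(2))  # (4950, 100)
```

In this sketch the cached values are tuples, which are immutable, so a value stored for one state cannot be changed through another state.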

To make sure that the method is correct, I reimplemented it in C++ and it got Accepted, as expected: 60403673. Could you please help me fix the issue in Python?

Python code:

import sys

mod = 998244353
MAX_LENGTH = 20
bound = [0] * MAX_LENGTH

def mul(a, b): return (a * b) % mod
def add(a, b):
    a += b
    if a < 0: a += mod
    if a >= mod: a -= mod
    return a

def digitize(num):
    for i in range(MAX_LENGTH):
        bound[i] = num % 10
        num //= 10

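def rec(smaller, start, pos, mask):
    global k
    # more than k distinct digits used: nothing valid can come from this state
    if bit_count[mask] > k:
        return [0, 0]
    if pos == -1:
        return [0, 1]

    # if the two following lines are removed, the code returns correct results
    if dp[smaller][start][pos][mask][0] != -1:
        return dp[smaller][start][pos][mask]

    res_sum = res_ways = 0
    for digit in range(0, 10):
        if smaller == 0 and digit > bound[pos]:
            continue
        new_smaller = smaller | (digit < bound[pos])
        new_start = start | (digit > 0) | (pos == 0)
        new_mask = (mask | (1 << digit)) if new_start == 1 else 0

        cur_sum, cur_ways = rec(new_smaller, new_start, pos - 1, new_mask)
        # accumulation assumed from the description above:
        # this digit adds digit * 10^pos to each of the cur_ways completions
        res_sum = (res_sum + digit * (10 ** pos) * cur_ways + cur_sum) % mod
        res_ways = (res_ways + cur_ways) % mod

    dp[smaller][start][pos][mask] = [res_sum, res_ways]
    return dp[smaller][start][pos][mask]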

def solve(upper_bound):
    global dp
    dp = 2 * [2 * [MAX_LENGTH * [(1 << 10) * [[-1, -1]]]]]
    digitize(upper_bound)
    ans = rec(0, 0, MAX_LENGTH - 1, 0)
    print(ans)
    return ans

inp = [int(x) for x in sys.stdin.read().split()]
l, r, k = inp[0], inp[1], inp[2]

bit_count = [0] * (1 << 10)
for i in range(1, 1 << 10): bit_count[i] = bit_count[i & (i - 1)] + 1
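ten_pow = [(10 ** i) % mod for i in range(0, MAX_LENGTH)]

# answer for [l, r] as prefix(r) - prefix(l - 1), assuming solve returns [sum, ways]
print((solve(r)[0] - solve(l - 1)[0]) % mod)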
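One thing that may be worth checking in the code above is how the dp table is built. In Python, multiplying a list repeats references to the same inner objects instead of copying them, so a table built with nested * can end up with many states sharing one [-1, -1] cell. The sketch below only demonstrates that language behaviour on a small 2-D table; the helper names make_shared and make_independent are assumptions for illustration, and whether this is the cause of the wrong answers here is not confirmed.

```python
def make_shared(rows, cols):
    # Same construction style as the dp table: every row is the same list object,
    # and every cell in it is the same [-1, -1] list object.
    return rows * [cols * [[-1, -1]]]

def make_independent(rows, cols):
    # List comprehensions build a fresh inner list for every cell.
    return [[[-1, -1] for _ in range(cols)] for _ in range(rows)]

shared = make_shared(2, 3)
shared[0][0][0] = 7             # write to one "state"
print(shared[1][2])             # [7, -1]: an unrelated cell changed as well

independent = make_independent(2, 3)
independent[0][0][0] = 7
print(independent[1][2])        # [-1, -1]: other cells are untouched
```

If the table is shared like this, a memo check on one state can see a value that was written for a different state, which would be consistent with the results becoming correct once the lookup is removed.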


