
Memoization with returning tuples does not work in Python

Revision en1, by beginner1010, 2019-09-11 20:58:39

Hello,

I tried the problem Segment Sum. The DP solution seemed pretty straightforward to me, so I decided to implement it in Python (the code is at the end of this post).

I used memoization, and the recursive function returns two numbers: each DP state stores the current sum and the number of ways under the problem's constraints. After hours of debugging, I couldn't find the problem in my code. I then removed the cache check, if dp[smaller][start][pos][mask][0] != -1: return dp[smaller][start][pos][mask], to see whether the DP table itself was at fault. Surprisingly, the code output correct results once these two lines were gone. It seems that something goes wrong when a recursive function returns a tuple (here, a list of size 2) in Python.
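As a sanity check on that suspicion, here is a minimal sketch, separate from the solution, of a memoized recursion that returns a (sum, ways) pair. It is a simplification made up for illustration: it caches with functools.lru_cache instead of the dp table from this post, and it drops the bound, start and mask parts of the state, so it simply covers every digit string of a given length.

```python
from functools import lru_cache

MOD = 998244353

@lru_cache(maxsize=None)
def rec(pos):
    """Return (sum, ways) over all digit strings filling positions pos..0."""
    if pos == -1:
        return (0, 1)  # one empty suffix, contributing 0 to the sum
    res_sum = res_ways = 0
    for digit in range(10):
        cur_sum, cur_ways = rec(pos - 1)
        # Each of the cur_ways suffixes gains digit * 10^pos.
        res_sum = (res_sum + digit * pow(10, pos, MOD) * cur_ways + cur_sum) % MOD
        res_ways = (res_ways + cur_ways) % MOD
    return (res_sum, res_ways)

print(rec(2))  # (499500, 1000): sum and count of 0..999
```

The cached tuples come back intact here, which suggests that returning a pair from a memoized function is not a problem in itself and that the storage behind the cache is the more likely place to look.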

To make sure that the method is correct, I reimplemented it in C++ and it got Accepted, as expected: 60403673. Could you please help me fix the issue in Python?
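For small inputs, a brute-force reference is another way to cross-check the DP. The sketch below assumes that the task is to sum all numbers in [l, r] that contain at most k distinct digits, modulo 998244353, which is what the Python code in this post computes. The function name brute_segment_sum is hypothetical.

```python
MOD = 998244353

def brute_segment_sum(l, r, k):
    """Sum of all n in [l, r] with at most k distinct digits, modulo MOD."""
    total = 0
    for n in range(l, r + 1):
        # set(str(n)) holds the distinct decimal digits of n.
        if len(set(str(n))) <= k:
            total += n
    return total % MOD

print(brute_segment_sum(10, 50, 2))  # 1230: every two-digit number qualifies
```

It is only practical for small ranges, but it makes it easy to compare the memoized and non-memoized versions on random small tests.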

Python code:

import sys

mod = 998244353
MAX_LENGTH = 20
bound = [0] * MAX_LENGTH

def mul(a, b): return (a * b) % mod
def add(a, b):
    a += b
    if a < 0: a += mod
    if a >= mod: a -= mod
    return a

def digitize(num):
    for i in range(MAX_LENGTH):
        bound[i] = num % 10
        num //= 10

def rec(smaller, start, pos, mask):
    global k
    if bit_count[mask] > k:
        return [0, 0]
    if pos == -1:
        return [0, 1]

    # if the two following lines are removed, the code returns correct results
    if dp[smaller][start][pos][mask][0] != -1:  
        return dp[smaller][start][pos][mask]

    res_sum = res_ways = 0
    for digit in range(0, 10):
        if smaller == 0 and digit > bound[pos]:
            continue
        new_smaller = smaller | (digit < bound[pos])
        new_start = start | (digit > 0) | (pos == 0)
        new_mask = (mask | (1 << digit)) if new_start == 1 else 0

        cur_sum, cur_ways = rec(new_smaller, new_start, pos - 1, new_mask)
        res_sum = add(res_sum, add(mul(mul(digit, ten_pow[pos]), cur_ways), cur_sum))
        res_ways = add(res_ways, cur_ways)

    dp[smaller][start][pos][mask][0], dp[smaller][start][pos][mask][1] = res_sum, res_ways
    return dp[smaller][start][pos][mask]

def solve(upper_bound):
    global dp
    dp = 2 * [2 * [MAX_LENGTH * [(1 << 10) * [[-1, -1]]]]]
    digitize(upper_bound)
    ans = rec(0, 0, MAX_LENGTH - 1, 0)
    print(ans)
    return ans[0]

inp = [int(x) for x in sys.stdin.read().split()]
l, r, k = inp[0], inp[1], inp[2]

bit_count = [0] * (1 << 10)
for i in range(1, 1 << 10): bit_count[i] = bit_count[i & (i - 1)] + 1
ten_pow = [(10 ** i) % mod for i in range(0, MAX_LENGTH)]

print(add(solve(r), -solve(l - 1)))
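One thing worth checking in the code above is how dp is built in solve. Multiplying a list repeats references to the same inner object rather than copying it, so every state in a table built as 2 * [2 * [...]] ends up pointing at one shared [-1, -1] cell. The first write would then make every state look already computed, which would be consistent with the code working once the cache check is removed. The sketch below shows the effect on a tiny table and one possible way to build independent cells. The helper make_table is a hypothetical name, not part of the original solution.

```python
# A miniature of the table: two "states", each meant to hold [sum, ways].
shared = 2 * [[-1, -1]]
shared[0][0] = 7
print(shared[0] is shared[1])  # True: both entries are the same list object
print(shared[1][0])            # 7: the write to state 0 is visible in state 1

def make_table(dims):
    """Build a nested list of the given dimensions with a fresh [-1, -1] per cell."""
    if not dims:
        return [-1, -1]
    # The comprehension calls make_table again for every element,
    # so no two cells share the same list.
    return [make_table(dims[1:]) for _ in range(dims[0])]

independent = make_table((2, 2, 3, 4))
independent[0][0][0][0][0] = 7
print(independent[1][1][2][3][0])  # -1: other states are unaffected
```

With this approach, the table in solve could be created as make_table((2, 2, MAX_LENGTH, 1 << 10)). Storing an immutable tuple per state (or None for "not computed") would be another option.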
