beginner1010's blog

By beginner1010, 5 years ago, In English

Hello,

I tried the problem Segment Sum. The DP solution seemed straightforward to me, so I decided to implement it in Python (the code is below).

I used memoization. The recursive function returns two numbers, i.e. each DP state stores two values: the running sum and the number of ways that satisfy the constraints from the current state. After hours of debugging I couldn't find anything wrong with my code, so I removed the two lines `if dp[smaller][start][pos][mask][0] != -1: return dp[smaller][start][pos][mask]` to check whether the DP table itself was the problem. Surprisingly, the code output correct results once those two lines were gone. It seems that something goes wrong when a recursive function in Python returns tuples (or, here, lists of size 2).

To make sure that the method is correct, I reimplemented it in C++ and it got Accepted, as expected: 60403673. Could you please help me fix the issue in Python?

Python code:

```python
import sys

mod = 998244353
MAX_LENGTH = 20
bound = [0] * MAX_LENGTH

def mul(a, b): return (a * b) % mod
def add(a, b):
    a += b
    if a < 0: a += mod
    if a >= mod: a -= mod
    return a

def digitize(num):
    for i in range(MAX_LENGTH):
        bound[i] = num % 10
        num //= 10

def rec(smaller, start, pos, mask):
    global k
    if bit_count[mask] > k:
        return [0, 0]
    if pos == -1:
        return [0, 1]

    # if the two following lines are removed, the code returns correct results
    if dp[smaller][start][pos][mask][0] != -1:  
        return dp[smaller][start][pos][mask]

    res_sum = res_ways = 0
    for digit in range(0, 10):
        if smaller == 0 and digit > bound[pos]:
            continue
        new_smaller = smaller | (digit < bound[pos])
        new_start = start | (digit > 0) | (pos == 0)
        new_mask = (mask | (1 << digit)) if new_start == 1 else 0

        cur_sum, cur_ways = rec(new_smaller, new_start, pos - 1, new_mask)
        res_sum = add(res_sum, add(mul(mul(digit, ten_pow[pos]), cur_ways), cur_sum))
        res_ways = add(res_ways, cur_ways)

    dp[smaller][start][pos][mask][0], dp[smaller][start][pos][mask][1] = res_sum, res_ways
    return dp[smaller][start][pos][mask]

def solve(upper_bound):
    global dp
    dp = 2 * [2 * [MAX_LENGTH * [(1 << 10) * [[-1, -1]]]]]
    digitize(upper_bound)
    ans = rec(0, 0, MAX_LENGTH - 1, 0)
    print(ans, file=sys.stderr)  # debug output, kept off stdout
    return ans[0]

inp = [int(x) for x in sys.stdin.read().split()]
l, r, k = inp[0], inp[1], inp[2]

bit_count = [0] * (1 << 10)
for i in range(1, 1 << 10): bit_count[i] = bit_count[i & (i - 1)] + 1
ten_pow = [(10 ** i) % mod for i in range(0, MAX_LENGTH)]

print(add(solve(r), -solve(l - 1)))
```
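When a memoized DP gives wrong answers, comparing it against a brute force on small inputs narrows things down quickly. The sketch below assumes, as the `bit_count[mask] > k` check in the code suggests, that the task is to sum every number in [l, r] that uses at most k distinct decimal digits, modulo 998244353. `brute` is a hypothetical helper name, and it only works for small ranges.

```python
MOD = 998244353

def brute(l, r, k):
    # Assumption: sum every x in [l, r] whose decimal representation
    # uses at most k distinct digits, taken modulo 998244353.
    # Runs in O(r - l) time, so it is only a checker for small ranges.
    total = 0
    for x in range(l, r + 1):
        if len(set(str(x))) <= k:
            total += x
    return total % MOD

if __name__ == "__main__":
    print(brute(10, 50, 2))    # 1230
    print(brute(101, 154, 2))  # 2189
```

Running both versions on random small (l, r, k) and printing the first mismatch points directly at the failing state.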

Comment (5 years ago, +6):

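In your solve function, the * operator copies only references to the inner lists, not the lists themselves, so every dp[smaller][start][pos][mask] is the very same [-1, -1] object. For example: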

```python
>>> A = 2 * [[0, 0]]
>>> A
[[0, 0], [0, 0]]
>>> A[0][0] = 1
>>> A
[[1, 0], [1, 0]]
```

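This is likely the problem here: as soon as one state stores its result, every other state sees a first value other than -1 and returns that same shared pair. Build each level with a list comprehension instead, e.g. A = [[0, 0] for i in range(2)] (use xrange in Python 2), which creates a separate inner list for every element.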
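A minimal sketch of that fix applied to the dp table from solve, with the same shape (2 × 2 × MAX_LENGTH × 1024 pairs). `make_dp` is a hypothetical helper name; the snippet also reproduces the aliasing with the original one-liner for comparison.

```python
MAX_LENGTH = 20

def make_dp():
    # One fresh [-1, -1] pair per state: every level is built by a
    # comprehension, so no two entries share the same inner list.
    return [[[[[-1, -1] for _ in range(1 << 10)]
              for _ in range(MAX_LENGTH)]
             for _ in range(2)]
            for _ in range(2)]

# Original construction: all 81920 "states" are one shared list.
shared = 2 * [2 * [MAX_LENGTH * [(1 << 10) * [[-1, -1]]]]]
shared[0][0][0][0][0] = 5
print(shared[1][1][19][1023][0])  # 5, the write leaked into another state

dp = make_dp()
dp[0][0][0][0][0] = 5
print(dp[1][1][19][1023][0])  # -1, states are independent
```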

Reply (5 years ago, +21):

    Nested lists in Python are also slow, so it's best to avoid them. For example, a 4D list with shape (256, 256, 256, 2), about 33 million elements, takes about 5 seconds to create for me (Python 3), while a 1D version built with [-1] * (2*256**3) is so fast that I'd hit MLE before TLE. When numpy arrays are an option (unfortunately not here), use them. Otherwise, use a 1D list. Regular nested lists are only worthwhile when the inner sizes vary, and that flexibility always has a performance cost, just like ArrayList in Java.
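A sketch of the 1D-list approach for the dp table in the question, assuming the same state space (smaller, start in {0, 1}, pos < MAX_LENGTH, mask < 1024). `state_index`, `store` and `lookup` are hypothetical names. Returning a fresh tuple from `lookup` also means callers never hold a reference into the table.

```python
MAX_LENGTH = 20
NUM_MASKS = 1 << 10
NUM_STATES = 2 * 2 * MAX_LENGTH * NUM_MASKS

def state_index(smaller, start, pos, mask):
    # Row-major flattening of (smaller, start, pos, mask) into one integer.
    return ((smaller * 2 + start) * MAX_LENGTH + pos) * NUM_MASKS + mask

# Two slots per state: 2*i holds the sum, 2*i + 1 holds the number of ways.
# [-1] * n repeats a reference to the immutable int -1, which is safe:
# assigning to one slot rebinds only that slot.
dp = [-1] * (2 * NUM_STATES)

def store(smaller, start, pos, mask, res_sum, res_ways):
    i = 2 * state_index(smaller, start, pos, mask)
    dp[i], dp[i + 1] = res_sum, res_ways

def lookup(smaller, start, pos, mask):
    # Returns None for a state not computed yet, else a (sum, ways) tuple.
    i = 2 * state_index(smaller, start, pos, mask)
    if dp[i] == -1:
        return None
    return dp[i], dp[i + 1]
```

In rec, the memo check would become `cached = lookup(...)` followed by `if cached is not None: return cached`, and the final assignment a call to `store(...)`.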

Comment (5 years ago, 0):

Yes! That was the problem. Thank you very much.