dmkz's blog

By dmkz, history, 4 years ago, translation, In English

In this blog you can find an efficient implementation of Karatsuba multiplication of two polynomials, which can sometimes be used instead of FFT.

Hello everyone!

Someone asked me to write C++ code that computes the $$$1\,000\,000$$$-th Fibonacci number. This problem was solved in less than 1 second using Karatsuba multiplication. After that, I tried the same code on standard problems such as "we have two strings over the letters $$$\text{ATGC}$$$ and want to find the cyclic rotation of one of them that maximizes the number of matching positions", and these problems were solved successfully as well. There are a lot of naive implementations, but MrDindows helped me write the most efficient of all the simple implementations.
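To make the rotation problem concrete, here is a minimal sketch of how it reduces to polynomial multiplication. For each letter, the indicator array of one string is multiplied by the reversed indicator array of the other, and the count for rotation $$$k$$$ is read from coefficients $$$n-1-k$$$ and $$$2n-1-k$$$. The function name `rotation_matches` is hypothetical, and the product is computed naively here only to keep the sketch short; the assumption is that any polynomial multiplication routine could be substituted for the inner double loop.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: returns cnt[k] = number of positions i
// with s[i] == t[(i + k) % n], for every rotation k of t.
std::vector<long long> rotation_matches(const std::string& s, const std::string& t) {
    assert(s.size() == t.size());
    const int n = static_cast<int>(s.size());
    std::vector<long long> cnt(n, 0);
    for (char c : std::string("ATGC")) {
        // a marks where s has letter c, b marks where the REVERSED t has letter c.
        std::vector<long long> a(n), b(n), p(2 * n, 0);
        for (int i = 0; i < n; i++) {
            a[i] = (s[i] == c);
            b[i] = (t[n - 1 - i] == c);
        }
        // Naive product p = a * b; this is the step a faster multiplication would replace.
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                p[i + j] += a[i] * b[j];
        // Pairs without wrap-around land in p[n-1-k], wrapped pairs in p[2n-1-k].
        for (int k = 0; k < n; k++)
            cnt[k] += p[n - 1 - k] + p[2 * n - 1 - k];
    }
    return cnt;
}
```

The best rotation is then the index of the maximum element of the returned vector.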

Idea of Karatsuba multiply

We have two polynomials $$$a(x)$$$ and $$$b(x)$$$ of equal length $$$2n$$$ and want to multiply them. Suppose that $$$a(x) = a_0(x) + x^n \cdot a_1(x)$$$ and $$$b(x) = b_0(x) + x^n \cdot b_1(x)$$$, where $$$a_0, a_1, b_0, b_1$$$ have length $$$n$$$. Then we can compute $$$a(x) \cdot b(x)$$$ in the following steps:

  • Calculate $$$E(x) = (a_0(x) + a_1(x)) \cdot (b_0(x) + b_1(x))$$$
  • Calculate $$$r_0(x)=a_0(x)\cdot b_0(x)$$$
  • Calculate $$$r_2(x)=a_1(x) \cdot b_1(x)$$$
  • Now, result is $$$a(x) \cdot b(x) = r_0(x) + x^n \cdot \left(E(x) - r_0(x) - r_2(x)\right) + x^{2n} \cdot r_2(x)$$$

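We can see that instead of $$$4$$$ half-size multiplications we use only $$$3$$$, plus a linear number of additions. The running time therefore satisfies $$$T(2n) = 3T(n) + O(n)$$$, which gives $$$O(n^{\log_2(3)})$$$, or about $$$O(n^{1.58})$$$.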
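The steps above can be sketched directly with `std::vector`. This is a plain, unoptimized illustration, not the implementation from this blog: the names `Poly`, `naive_mul` and `karatsuba` are hypothetical, the length is assumed to be a power of two, and the base-case threshold of 4 is chosen only so that small inputs exercise the recursion.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Poly = std::vector<long long>;

// Schoolbook product; serves as the base case and as a reference.
Poly naive_mul(const Poly& a, const Poly& b) {
    if (a.empty() || b.empty()) return {};
    Poly res(a.size() + b.size() - 1, 0);
    for (std::size_t i = 0; i < a.size(); i++)
        for (std::size_t j = 0; j < b.size(); j++)
            res[i + j] += a[i] * b[j];
    return res;
}

// Karatsuba product of two polynomials of equal length n (a power of two).
// Returns 2n - 1 coefficients.
Poly karatsuba(const Poly& a, const Poly& b) {
    const std::size_t n = a.size();
    assert(b.size() == n && n > 0 && (n & (n - 1)) == 0);
    if (n <= 4) return naive_mul(a, b);  // illustrative threshold only
    const std::size_t h = n / 2;
    Poly a0(a.begin(), a.begin() + h), a1(a.begin() + h, a.end());
    Poly b0(b.begin(), b.begin() + h), b1(b.begin() + h, b.end());
    Poly sa(h), sb(h);
    for (std::size_t i = 0; i < h; i++) {
        sa[i] = a0[i] + a1[i];
        sb[i] = b0[i] + b1[i];
    }
    const Poly E  = karatsuba(sa, sb);  // (a0 + a1) * (b0 + b1)
    const Poly r0 = karatsuba(a0, b0);  // a0 * b0
    const Poly r2 = karatsuba(a1, b1);  // a1 * b1
    Poly res(2 * n - 1, 0);
    for (std::size_t i = 0; i < r0.size(); i++) {  // each part has n - 1 coefficients
        res[i]         += r0[i];
        res[i + h]     += E[i] - r0[i] - r2[i];
        res[i + 2 * h] += r2[i];
    }
    return res;
}
```

Copying the halves into new vectors at every level is wasteful; the implementation below avoids it by working on raw pointers.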

Efficient implementation

#pragma GCC optimize("Ofast,unroll-loops")
#pragma GCC target("avx,avx2,fma") 
namespace {
    template<int n, typename T>
    void mult(const T *__restrict a, const T *__restrict b, T *__restrict res) {
        if (n <= 64) { // if the length is small, naive multiplication is faster
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    res[i + j] += a[i] * b[j];
                }
            }
        } else {
            const int mid = n / 2;
            alignas(64) T btmp[n], E[n] = {};
            auto atmp = btmp + mid;
            for (int i = 0; i < mid; i++) {
                atmp[i] = a[i] + a[i + mid]; // atmp(x) - sum of the two halves of a(x)
                btmp[i] = b[i] + b[i + mid]; // btmp(x) - sum of the two halves of b(x)
            }
            mult<mid>(atmp, btmp, E); // Calculate E(x) = (alow(x) + ahigh(x)) * (blow(x) + bhigh(x))
            mult<mid>(a + 0, b + 0, res); // Calculate rlow(x) = alow(x) * blow(x)
            mult<mid>(a + mid, b + mid, res + n); // Calculate rhigh(x) = ahigh(x) * bhigh(x)
            for (int i = 0; i < mid; i++) { // Then, calculate rmid(x) = E(x) - rlow(x) - rhigh(x) and write in memory
                const auto tmp = res[i + mid];
                res[i + mid] += E[i] - res[i] - res[i + 2 * mid];
                res[i + 2 * mid] += E[i + mid] - tmp - res[i + 3 * mid];
            }
        }
    }
}
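A short usage sketch may help with the calling convention. As far as the code shows, `res` must hold $$$2n$$$ elements and be zero-filled before the call, because every level accumulates with `+=`, and $$$n$$$ should be a power of two so that the halves stay equal. The wrapper `multiply` below is a hypothetical helper, not part of the original code. The copy of `mult` uses `if constexpr` (C++17) instead of a plain `if`, which is my change to stop template instantiation at the base case; the pragmas are omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

template<int n, typename T>
void mult(const T *__restrict a, const T *__restrict b, T *__restrict res) {
    if constexpr (n <= 64) {            // naive base case, accumulates into res
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                res[i + j] += a[i] * b[j];
    } else {
        constexpr int mid = n / 2;
        alignas(64) T btmp[n], E[n] = {};
        T *atmp = btmp + mid;
        for (int i = 0; i < mid; i++) {
            atmp[i] = a[i] + a[i + mid];
            btmp[i] = b[i] + b[i + mid];
        }
        mult<mid>(atmp, btmp, E);             // E = (alow + ahigh) * (blow + bhigh)
        mult<mid>(a, b, res);                 // rlow  -> res[0 .. n)
        mult<mid>(a + mid, b + mid, res + n); // rhigh -> res[n .. 2n)
        for (int i = 0; i < mid; i++) {       // add rmid = E - rlow - rhigh at offset mid
            const auto tmp = res[i + mid];
            res[i + mid] += E[i] - res[i] - res[i + 2 * mid];
            res[i + 2 * mid] += E[i + mid] - tmp - res[i + 3 * mid];
        }
    }
}

// Hypothetical wrapper: multiplies two arrays of exactly N coefficients.
template<int N>
std::vector<long long> multiply(const std::vector<long long>& a,
                                const std::vector<long long>& b) {
    static_assert(N > 0 && (N & (N - 1)) == 0, "N is assumed to be a power of two");
    assert(a.size() == static_cast<std::size_t>(N));
    assert(b.size() == static_cast<std::size_t>(N));
    std::vector<long long> res(2 * N, 0);  // zero-filled: mult only ever adds to it
    mult<N>(a.data(), b.data(), res.data());
    return res;
}
```

Shorter inputs can be padded with zeros up to the next power of two before calling `multiply<N>`.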

Testing

Examples of solutions for real problems: 528D. Fuzzy Search and AtCoder Beginner Contest 149 E. Handshake.

Stress-test 512K for 64-bit and 32-bit coefficients

This code can sometimes be used instead of FFT, but be careful: to multiply two polynomials of length $$$2^{19}$$$, Karatsuba multiplication requires $$$3\,047\,478\,784$$$ operations. Of course, this can run in 1 second, but for greater lengths...

»
4 years ago, # |
  Vote: I like it +5 Vote: I do not like it

Don't forget that you also need modulo for most problems where polynomial multiplication is used.

  • »
    »
    4 years ago, # ^ |
      Vote: I like it 0 Vote: I do not like it

    This is one more reason to use Karatsuba multiplication: it works well with moduli that are "bad" for FFT.

    • »
      »
      »
      4 years ago, # ^ |
        Vote: I like it 0 Vote: I do not like it

      Yeah, that's one of the big reasons why I prefer Karatsuba over FFT in most situations. Just pointing out that it matters a bit for benchmarks and that you can't use this exact code right away because of overflows.

»
4 years ago, # |
  Vote: I like it +43 Vote: I do not like it

Can you make a benchmark comparing it with FFT?

  • »
    »
    4 years ago, # ^ |
    Rev. 2   Vote: I like it +26 Vote: I do not like it

    Code for testing

    Codeforces results
»
4 years ago, # |
  Vote: I like it 0 Vote: I do not like it

What is the reasoning behind allocating the temporary arrays inside the function body? Isn't there a problem with big array allocations on the stack making more cache misses (and with possible stack size limitations of other OJs)? What happens to the benchmarks if you preallocate those arrays?

  • »
    »
    4 years ago, # ^ |
      Vote: I like it 0 Vote: I do not like it

    Cache misses have nothing to do with stack usage. It's the order of data accesses that matters. As an extreme example, if you allocate a 1 GB array on the stack, but don't access it at all, then cache hits/misses are exactly the same because the "allocation" is only a change in upper bits in the stack register.

    In fact, the fixed nature of this recursion (let's assume sizes $$$2^n$$$, it doesn't hurt) means that this "allocation" inside functions really is just giving purpose to parts of the stack. When we're at depth $$$d$$$, the locations of these arrays are always the same regardless of the branch of recursion, and when we move to an adjacent depth, these arrays are also adjacent. That means it's pretty good cache-wise at high depths, where we have a lot of recursion.

    • »
      »
      »
      4 years ago, # ^ |
        Vote: I like it 0 Vote: I do not like it

      That makes sense, I was more referring to the fact that if you allocate variables both before the array allocation and after it, the ones after will not be colocated with the ones before, so to speak. That might not be a problem at all (I'm not sure, to be honest), but still, some configurations limit the stack size and therefore huge arrays allocated on the stack might be problematic (?)

      • »
        »
        »
        »
        4 years ago, # ^ |
          Vote: I like it +10 Vote: I do not like it

        Yeah, with 8 MB stacks, this can fail. Easy fix: allocate one chunk on the heap and assign just pointers to each recursion level, it'll do the exact same thing. ¯\_(ツ)_/¯
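A minimal sketch of that fix, assuming a runtime length instead of a template parameter: one scratch buffer is allocated on the heap, and each recursion level takes the first $$$2n$$$ elements of what it is given and passes the rest down. Since the levels need $$$2n + n + n/2 + \dots < 4n$$$ elements in total, a buffer of $$$4n$$$ is enough. The names `mult_buf` and `multiply_heap` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// a, b: n coefficients each; res: 2n coefficients, zero-filled by the caller;
// buf: scratch space for this level and all deeper levels.
void mult_buf(const long long* a, const long long* b, long long* res,
              long long* buf, int n) {
    if (n <= 64) {                      // naive base case, needs no scratch space
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                res[i + j] += a[i] * b[j];
        return;
    }
    const int mid = n / 2;
    long long* btmp = buf;              // mid elements
    long long* atmp = buf + mid;        // mid elements
    long long* E    = buf + n;          // n elements
    long long* next = buf + 2 * n;      // scratch handed to the deeper levels
    for (int i = 0; i < mid; i++) {
        atmp[i] = a[i] + a[i + mid];
        btmp[i] = b[i] + b[i + mid];
    }
    std::fill(E, E + n, 0LL);           // E is reused across calls, so clear it
    mult_buf(atmp, btmp, E, next, mid);
    mult_buf(a, b, res, next, mid);
    mult_buf(a + mid, b + mid, res + n, next, mid);
    for (int i = 0; i < mid; i++) {
        const long long tmp = res[i + mid];
        res[i + mid] += E[i] - res[i] - res[i + 2 * mid];
        res[i + 2 * mid] += E[i + mid] - tmp - res[i + 3 * mid];
    }
}

// Hypothetical wrapper; n is assumed to be a power of two.
std::vector<long long> multiply_heap(const std::vector<long long>& a,
                                     const std::vector<long long>& b) {
    const int n = static_cast<int>(a.size());
    assert(b.size() == a.size() && n > 0 && (n & (n - 1)) == 0);
    std::vector<long long> res(2 * n, 0), buf(4 * n);
    mult_buf(a.data(), b.data(), res.data(), buf.data(), n);
    return res;
}
```

The three recursive calls at one level can share `next` because each call finishes with its scratch space before the following one starts.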

        In Karatsuba, cache isn't much of a bottleneck though. It can be if you jump around in memory too much at high depths, but that's not the case here. Just look at those beautiful linear passes. That's where most work is done and the bottleneck will be mostly the number of add/multiply operations and how well they can be pipelined. Even more so when modulo is involved.

»
4 years ago, # |
  Vote: I like it 0 Vote: I do not like it

More optimizations:
1. Using a Comba multiplier instead of naive multiplication (around 1.35 times faster).
2. Pre-allocating 4*N space up front to hold the temporaries instead of stack-based allocation.

»
4 years ago, # |
Rev. 3   Vote: I like it +8 Vote: I do not like it

Mine is slower (1 s to multiply vectors of size $$$200\,000$$$), but it can handle two arrays of different sizes, and it doesn't use vectorization, so it can probably be sped up nicely too. The complexity is $$$O(\max(s_1, s_2) \cdot \min(s_1, s_2)^{0.6})$$$. It works better than FFT for arrays smaller than maybe $$$30\,000$$$ and/or if the sizes differ a lot.

code
  • »
    »
    4 years ago, # ^ |
    Rev. 2   Vote: I like it 0 Vote: I do not like it

    Thanks for sharing your implementation. On your test case ($$$200123$$$ random ints from $$$0$$$ to $$$9$$$), the implementation from the blog with static arrays runs in $$$218\text{ ms}$$$.

    code

    I will think about a template version for unequal lengths that are both powers of two.

»
4 years ago, # |
  Vote: I like it 0 Vote: I do not like it

Wow, nice <3

»
12 months ago, # |
  Vote: I like it -18 Vote: I do not like it

How could this code be modified to work in some modulo? I've tried to modify it for this problem about polynomial multiplication, but got TLE (you can find the code for my submission here). In this case using int alone gives WA because of overflow, so I had to use an intermediate casting to long long before applying the modulo.
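One possible direction, shown as an untuned sketch rather than a drop-in replacement for the code in the blog: keep every coefficient reduced to $$$[0, \text{MOD})$$$, form products in 64-bit unsigned arithmetic, and replace each `+` and `-` of the combination step with a modular version. The modulus $$$10^9 + 7$$$, the threshold of 32, and the names `add_mod`, `sub_mod` and `karatsuba_mod` are all assumptions made for illustration. Reducing after every product in the base case is the slow part; accumulating several products before each reduction would probably be needed to avoid TLE.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using u32 = std::uint32_t;
using u64 = std::uint64_t;

constexpr u32 MOD = 1000000007;  // assumed modulus; 2 * MOD still fits in 32 bits

inline u32 add_mod(u32 x, u32 y) { const u32 s = x + y; return s >= MOD ? s - MOD : s; }
inline u32 sub_mod(u32 x, u32 y) { return x >= y ? x - y : x + MOD - y; }

// Inputs: n coefficients each, all already in [0, MOD), n a power of two.
// Returns 2n coefficients of the product modulo MOD (the last one is 0).
std::vector<u32> karatsuba_mod(const std::vector<u32>& a, const std::vector<u32>& b) {
    const std::size_t n = a.size();
    assert(b.size() == n && n > 0 && (n & (n - 1)) == 0);
    std::vector<u32> res(2 * n, 0);
    if (n <= 32) {  // illustrative threshold
        for (std::size_t i = 0; i < n; i++)
            for (std::size_t j = 0; j < n; j++)
                res[i + j] = static_cast<u32>((res[i + j] + static_cast<u64>(a[i]) * b[j]) % MOD);
        return res;
    }
    const std::size_t h = n / 2;
    std::vector<u32> a0(a.begin(), a.begin() + h), a1(a.begin() + h, a.end());
    std::vector<u32> b0(b.begin(), b.begin() + h), b1(b.begin() + h, b.end());
    std::vector<u32> sa(h), sb(h);
    for (std::size_t i = 0; i < h; i++) {
        sa[i] = add_mod(a0[i], a1[i]);
        sb[i] = add_mod(b0[i], b1[i]);
    }
    const std::vector<u32> E  = karatsuba_mod(sa, sb);
    const std::vector<u32> r0 = karatsuba_mod(a0, b0);
    const std::vector<u32> r2 = karatsuba_mod(a1, b1);
    for (std::size_t i = 0; i < n; i++) {  // each partial product has 2h = n entries
        res[i]     = add_mod(res[i], r0[i]);
        res[i + h] = add_mod(res[i + h], sub_mod(sub_mod(E[i], r0[i]), r2[i]));
        res[i + n] = add_mod(res[i + n], r2[i]);
    }
    return res;
}
```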