
rembocoder's blog

By rembocoder, 22 months ago, in English

There are already so many materials on binary search that it seems everyone must know how to code it. However, every material presents its own approach, and I still see people getting lost and making bugs in such a simple thing. That's why I want to describe the approach I find the most elegant. Among the variety of articles, I have seen my approach only in this comment (in a less generalized form).

UPD: Also in nor's blog.

The problem

Suppose you want to find the last element less than x in a sorted array and just store a segment of candidates for the answer [l, r]. If you write:

int l = 0, r = n - 1; // the segment of candidates
while (l < r) { // there are more than 1 candidate
    int mid = (l + r) / 2;
    if (a[mid] < x) {
        l = mid;
    } else {
        r = mid - 1;
    }
}
cout << l;

...you will run into a problem: suppose l = 3, r = 4, then mid = 3. If a[mid] < x, you will end up in a loop (the next iteration will be again on [3, 4]).

Okay, it can be fixed by rounding mid up: mid = (l + r + 1) / 2. But we have another problem: what if there are no such elements in the array? We would need an extra check for that case. As a result, we get rather ugly code that does not generalize well.
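For comparison, the patched closed-segment version might look like the sketch below. The function name last_less_closed and the convention of returning -1 when no element qualifies are assumptions made for illustration.

```cpp
#include <vector>

// Index of the last element less than x in the sorted vector a,
// or -1 if there is no such element (hypothetical convention).
int last_less_closed(const std::vector<int>& a, int x) {
    int n = static_cast<int>(a.size());
    if (n == 0 || a[0] >= x) {
        return -1; // the extra check: no element is less than x
    }
    int l = 0, r = n - 1; // closed segment [l, r] of candidates
    while (l < r) {
        int mid = (l + r + 1) / 2; // rounded up, so l = mid always shrinks the segment
        if (a[mid] < x) {
            l = mid;
        } else {
            r = mid - 1;
        }
    }
    return l;
}
```

Both patches (the rounding and the early return) have to be remembered separately, which is exactly the kind of detail that is easy to get wrong.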

My approach

Let's generalize the problem we want to solve. We have a statement about an integer n that is true for all integers up to some bound and false for every n beyond that bound. We want to find the last n for which the statement is true.

First of all, we will use half-intervals instead of segments (one border is inclusive, the other is exclusive). Half-intervals are in general very useful, elegant and conventional in programming; I recommend using them as much as possible. In our case we choose some small l for which we know beforehand that the statement is true and some big r for which we know it is false. Then the range of candidates is [l, r).

My code for that problem would be:

int l = -1, r = n; // a half-interval [l, r) of candidates
while (r - l > 1) { // there are more than 1 candidate
    int mid = (l + r) / 2;
    if (a[mid] < x) {
        l = mid; // now it's the largest for which we know it's true
    } else {
        r = mid; // now it's the smallest for which we know it's false
    }
}
cout << l; // in the end we are left with a range [l, l + 1)

The binary search only checks numbers strictly between the initial l and r. It never checks the statement for l and r themselves: it trusts you that it is true for l and false for r (so a[-1] and a[n] are never accessed). Here we consider -1 a valid answer; it means that no element of the array is less than x.

Notice how there are no "+ 1" or "- 1" in my code, no extra checks are needed, and no infinite loops are possible (since mid is always strictly between the current l and r).
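The same loop can be wrapped into a reusable helper. This is only a sketch: the name last_true and the template parameter Pred are assumptions, and the helper relies on the caller guaranteeing that the predicate is true for l and false for r.

```cpp
#include <cstdint>

// Returns the last n in the half-interval [l, r) for which pred(n) is true.
// Assumes pred is true up to some bound and false afterwards, true for l and
// false for r. pred is only called for values strictly between l and r.
template <class Pred>
std::int64_t last_true(std::int64_t l, std::int64_t r, Pred pred) {
    while (r - l > 1) {
        std::int64_t mid = l + (r - l) / 2; // strictly between l and r
        if (pred(mid)) {
            l = mid; // largest value known to be true
        } else {
            r = mid; // smallest value known to be false
        }
    }
    return l;
}
```

With such a helper, the array problem above would become a call like last_true(-1, n, [&](std::int64_t i) { return a[i] < x; }).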

Reverse problem

The only variation you need to keep in mind is that half of the time you need to find not the last but the first number for which something is true. In that case the statement must be false for all smaller numbers and true starting from some number.

We will do pretty much the same thing, but now r will be an inclusive border, while l will be non-inclusive. In other words, l is now some number for which we know the statement to be false, and r is some number for which we know it to be true. Suppose I want to find the first number n for which n * (n + 1) >= x (x is positive):

int l = 0, r = x; // a half-interval (l, r] of candidates
while (r - l > 1) { // there are more than 1 candidate
    int mid = (l + r) / 2;
    if (mid * (mid + 1) >= x) {
        r = mid; // now it's the smallest for which we know it's true
    } else {
        l = mid; // now it's the largest for which we know it's false
    }
}
cout << r; // in the end we are left with a range (r - 1, r]

Just be careful not to choose r too large: the product mid * (mid + 1) can overflow.
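One possible way to keep the product in range is to use a 64-bit type and cap r. The sketch below assumes 1 <= x <= 1e18; the function name first_n and the cap of 1e9 are choices made for this illustration (1e9 * (1e9 + 1) is already at least 1e18, and every mid stays below 1e9, so the product fits in int64_t).

```cpp
#include <algorithm>
#include <cstdint>

// First n >= 1 with n * (n + 1) >= x, assuming 1 <= x <= 1e18.
std::int64_t first_n(std::int64_t x) {
    std::int64_t l = 0; // 0 * 1 >= x is false for positive x
    std::int64_t r = std::min<std::int64_t>(x, 1000000000); // r * (r + 1) >= x is true
    while (r - l > 1) { // half-interval (l, r] of candidates
        std::int64_t mid = l + (r - l) / 2;
        if (mid * (mid + 1) >= x) {
            r = mid; // smallest value known to be true
        } else {
            l = mid; // largest value known to be false
        }
    }
    return r;
}
```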

An example

1201C - Maximum Median. You are given an array a of odd length n, and in one operation you can increase any element by 1. What is the maximal possible median of the array that can be achieved in k steps?

Consider a statement about the number x: we can make the median no less than x in at most k steps. Of course it is always true up to some number and then always false, so we can use binary search. As we need the last number for which this is true, we will use a normal half-interval [l, r).

To check for a given x, we can use a property of the median. Since n is odd, the median is no less than x iff at least (n + 1) / 2 elements are no less than x. The cheapest way to achieve that is to raise the largest (n + 1) / 2 elements, i.e. the elements of the sorted array starting from index n / 2.

Under the given constraints a median of at least 1 is always reachable, so l will be equal to 1. But even if there is a single element equal to 1e9 and k is also 1e9, we still can't reach median 2e9 + 1, so r will be equal to 2e9 + 1. Implementation:

#define int int64_t

int n, k;
cin >> n >> k;
vector<int> a(n);
for (int i = 0; i < n; i++) {
    cin >> a[i];
}
sort(a.begin(), a.end());
int l = 1, r = 2e9 + 1; // a half-interval [l, r) of candidates
while (r - l > 1) {
    int mid = (l + r) / 2;
    int cnt = 0; // the number of steps needed
    for (int i = n / 2; i < n; i++) { // go over the largest half
        if (a[i] < mid) {
            cnt += mid - a[i];
        }
    }
    if (cnt <= k) {
        l = mid;
    } else {
        r = mid;
    }
}
cout << l << endl;

Conclusion

Hope I've made it clearer and some of you will switch to this implementation. To clarify, occasionally other implementations can be more fitting, for example in interactive problems, or whenever we need to think in terms of a search interval rather than in terms of the first/last number for which something is true.


»
22 months ago, # |
  Vote: I like it +10 Vote: I do not like it

Thanks for the blog!

Re: example task, you forgot to paste the ever important #define int int64_t from your old solution. :)

And on the topic of binsearch, there is also the binary jumps perspective where you (in a way...) don't explicitly remember either right or mid. I don't tend to use it, but perhaps it may appeal to somebody more than other methods.

For the example task, the relevant part could go:

code

Submission: 158749229

»
22 months ago, # |
  Vote: I like it +31 Vote: I do not like it

In my opinion the "Reverse problem" part is excessive. In terms of code it basically replicates the approach of the previous part. The very same code finds both: l is the last element for which the condition takes one value, and r is the first element for which the condition takes the other value.
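To illustrate this point, here is a sketch that returns both borders from a single loop. The function name split_point and the pair return type are assumptions for this example.

```cpp
#include <utility>
#include <vector>

// For a sorted vector a, returns {l, r} with r == l + 1, where
// l is the last index with a[l] < x (or -1 if there is none) and
// r is the first index with a[r] >= x (or a.size() if there is none).
std::pair<int, int> split_point(const std::vector<int>& a, int x) {
    int l = -1, r = static_cast<int>(a.size());
    while (r - l > 1) {
        int mid = l + (r - l) / 2;
        if (a[mid] < x) {
            l = mid;
        } else {
            r = mid;
        }
    }
    return {l, r};
}
```

Whether you print l or r afterwards decides whether you solved the "last true" or the "first false" problem.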

  • »
    »
    22 months ago, # ^ |
    Rev. 2   Vote: I like it 0 Vote: I do not like it

    You don't even have to confuse yourself with different types of intervals. This approach makes so much more sense.

»
22 months ago, # |
  Vote: I like it +54 Vote: I do not like it

Thanks for the blog! I used the same implementation before C++20, and I had a completely different interpretation, so it's nice to learn a new way of thinking about it. What I had as my invariant was the following (and it seems a bit more symmetric than in this blog):

Suppose we have a predicate $$$p$$$ that returns true for some prefix of $$$[L, R]$$$, and false for the remaining suffix. Then at any point in the algorithm:

  1. $$$l$$$ is the rightmost element for which you can prove that $$$p(l)$$$ is true.
  2. $$$r$$$ is the leftmost element for which you can prove that $$$p(r)$$$ is false.
  3. The remaining interval to search in is $$$(l, r)$$$.

$$$r - l > 1$$$ corresponds to having at least one unexplored element, and when the algorithm ends, $$$l, r$$$ are the positions of the rightmost true and the leftmost false. Intuitively, it is trying to look at positions like l]...[r, and trying to bring the ] and [ together, to find the "breaking point" (more details in my blog).

  • »
    »
    22 months ago, # ^ |
      Vote: I like it 0 Vote: I do not like it

    I was going through previous posts on binary search and missed that in the middle of your post you introduce the same implementation as mine. Sorry, I should have been more attentive.

    As for your comment: yes, I also state this in the comments of my code. But with your and SirShokoladina's comments I now see that such a formulation, in terms of moving the borders together, can be even more obvious for a beginner.

»
22 months ago, # |
  Vote: I like it -11 Vote: I do not like it

tl;dr: half-open intervals are superior.

»
22 months ago, # |
  Vote: I like it -13 Vote: I do not like it

we can do the same with doubles but instead of while(r-l>1)

we must use while(r-l>eps)

  • »
    »
    22 months ago, # ^ |
      Vote: I like it +55 Vote: I do not like it

    This is known to be bad due to floating point precision issues. For floating point numbers, use a fixed number of iterations instead.
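The problem with r - l > eps is that for large magnitudes neighbouring doubles can be farther apart than eps, so the loop may never terminate. A sketch of the fixed-iteration variant, using a square root as a stand-in predicate (the function name bisect_sqrt and the count of 100 iterations are assumptions for illustration):

```cpp
// Approximates sqrt(x) for x >= 0 by bisection with a fixed number of iterations.
double bisect_sqrt(double x) {
    double l = 0.0;               // l * l <= x holds
    double r = x < 1.0 ? 1.0 : x; // r * r >= x holds
    for (int it = 0; it < 100; it++) { // no eps: the loop always terminates
        double mid = (l + r) / 2;
        if (mid * mid <= x) {
            l = mid;
        } else {
            r = mid;
        }
    }
    return l;
}
```

Each iteration halves the interval, so 100 iterations are more than enough to exhaust the precision of a double for inputs of moderate size.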

    • »
      »
      »
      22 months ago, # ^ |
        Vote: I like it 0 Vote: I do not like it

      or use Um_nik's bs

      while(min(r-l, r/l-1) > er_eps) {
      	T mid = r<=1 ? (l+r)/2 : l>=1 ? sqrt(l*r) : 1;
      	if(can(mid)) l = mid; else r = mid;
      }
      
»
22 months ago, # |
  Vote: I like it +29 Vote: I do not like it

In the example task your code only works because of #define int int64_t. Indeed, if the answer is at least 1e9, then during the second iteration you compute mid = (l + r) / 2 with l + r around 3e9, which overflows a 32-bit int, so mid becomes negative.

To avoid this, one can, indeed, work with 64-bit type, or use mid = l + (r - l) / 2, which is arithmetically the same, but does not overflow. Since C++20 there is also std::midpoint, which does the same.
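A small sketch of the second option with the numbers from the example task (the function name mid_no_overflow is an assumption; it requires l <= r and that r - l itself fits in 32 bits):

```cpp
#include <cstdint>

// Midpoint of l and r rounded down, for l <= r whose difference fits in int32_t.
// Unlike (l + r) / 2, the intermediate value r - l never exceeds r.
std::int32_t mid_no_overflow(std::int32_t l, std::int32_t r) {
    return l + (r - l) / 2;
}
```

For l = 1e9 and r = 2e9 + 1 the sum l + r would not fit in 32 bits, while r - l = 1e9 + 1 does.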

  • »
    »
    22 months ago, # ^ |
    Rev. 3   Vote: I like it 0 Vote: I do not like it

    $$$l + (r - l) / 2$$$ overflows for $$$l = \mathtt{INT\_MIN}$$$ and $$$r = \mathtt{INT\_MAX}$$$. Here is a way that works for any $$$l \le r$$$ (you might have needed it before C++20 came in with std::midpoint, and it might even be faster):

    std::int32_t m = l + static_cast<std::int32_t>((static_cast<std::uint32_t>(r) - static_cast<std::uint32_t>(l))/ 2);
    

    or the more readable, but perhaps less "industry" standard:

    int32_t m = l + int32_t((uint32_t(r) - uint32_t(l)) / 2);
    

    Why this works: After replacing $$$l, r$$$ by their values modulo $$$2^{32}$$$ by doing a static cast to std::uint32_t, it is guaranteed that the difference $$$d$$$ between them is equal to $$$r - l$$$ modulo $$$2^{32}$$$. Since $$$l \le r$$$, $$$r - l$$$ has to be non-negative, and the difference between them is less than $$$2^{32}$$$, the required difference is indeed equal to $$$d$$$. Now $$$d / 2$$$ fits in a std::int32_t, so we can safely cast back to std::int32_t, and carry on with our computation of $$$m$$$.

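    Note that the static_cast to unsigned is well defined in every C++ standard (the value is reduced modulo $$$2^{32}$$$). What was implementation-defined before C++20 is converting an out-of-range value back to a signed type, and that does not happen here because $$$d / 2$$$ always fits.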

    The same thing can be done with 32 replaced by 64 everywhere, making the code independent of 128-bit types.

    Also since we're already on the topic of overflow, it is a bit unfortunate that the binary search can be only done where the range of possible values (that is, where we can evaluate the predicate, rather than the return value) is $$$(\mathtt{INT\_MIN}, \mathtt{INT\_MAX})$$$, so we miss out the integers $$$\mathtt{INT\_MIN}, \mathtt{INT\_MAX}$$$. I am not completely sure that this is a bad thing either, though.

    Edit: similarly, the condition $$$r - l > 1$$$ suffers from overflows. We can fix it by simply checking if $$$l \ne \mathtt{INT\_MAX}$$$ and $$$l + 1 < r$$$, or just uint32_t(r) - uint32_t(l) > 1 (the starting $$$l, r$$$ should satisfy $$$r \ge l$$$ for the second idea to work).
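Putting the midpoint and the loop condition from this comment together, a sketch could look as follows. The function names mid_any and more_than_one_apart are assumptions; both require l <= r.

```cpp
#include <cstdint>

// Midpoint of l and r rounded down, valid for any l <= r in int32_t.
std::int32_t mid_any(std::int32_t l, std::int32_t r) {
    // The unsigned difference equals r - l exactly, since 0 <= r - l < 2^32.
    std::uint32_t d = static_cast<std::uint32_t>(r) - static_cast<std::uint32_t>(l);
    return l + static_cast<std::int32_t>(d / 2); // d / 2 < 2^31, so the cast is safe
}

// Overflow-free replacement for the loop condition r - l > 1.
bool more_than_one_apart(std::int32_t l, std::int32_t r) {
    return static_cast<std::uint32_t>(r) - static_cast<std::uint32_t>(l) > 1;
}
```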

»
22 months ago, # |
  Vote: I like it +83 Vote: I do not like it

Since we are talking about binary search, let me remind everyone what the best implementation is.

int ans = 0;
for (int k = 1 << 29; k != 0; k /= 2) {
  if (can(ans + k)) {
    ans += k;
  }
}
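For readers who have not seen this form: it builds the answer bit by bit, from the highest bit down. A sketch wrapped into a function (the name largest_true is an assumption), which relies on can(0) being true and on can switching from true to false only once:

```cpp
// Largest ans in [0, 2^30) with can(ans) true, assuming can(0) is true
// and can is true up to some bound and false afterwards.
template <class Pred>
int largest_true(Pred can) {
    int ans = 0;
    for (int k = 1 << 29; k != 0; k /= 2) {
        if (can(ans + k)) { // try to set the current bit
            ans += k;
        }
    }
    return ans;
}
```

The largest argument ever passed to can is 2^30 - 1, so ans + k does not overflow a 32-bit int. If the true answer is 2^30 or more, the function simply returns 2^30 - 1.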
»
22 months ago, # |
  Vote: I like it 0 Vote: I do not like it

I'll note that this is only possible when there is a sentinel value that can never be returned. If you are implementing an analogue of lower_bound, you need prev(begin), or you have to switch to another implementation (basically adding one).
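One possible workaround, sketched below under the assumption of random-access iterators, is to keep the borders as signed offsets from the beginning instead of iterators. Then the left sentinel is just the number -1 and prev(begin) is never formed. The function name lower_bound_by_offset is hypothetical.

```cpp
#include <cstddef>

// Analogue of std::lower_bound for random-access iterators:
// returns an iterator to the first element that is not less than x, or last.
template <class It, class T>
It lower_bound_by_offset(It first, It last, const T& x) {
    std::ptrdiff_t l = -1;           // offset treated as "less than x", never dereferenced
    std::ptrdiff_t r = last - first; // offset treated as "not less than x", never dereferenced
    while (r - l > 1) {
        std::ptrdiff_t mid = l + (r - l) / 2;
        if (*(first + mid) < x) {
            l = mid;
        } else {
            r = mid;
        }
    }
    return first + r;
}
```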