Блог пользователя rembocoder

Автор rembocoder, история, 2 года назад, перевод, По-русски

Казалось бы, по бинпоиску написано столько материалов, что все должны знать, как писать его, однако в каждом материале похоже рассказывается свой подход, и я вижу, что люди теряются и допускают ошибки в такой простой штуке. Поэтому я хочу рассказать о подходе, который нахожу наиболее элегантным. Из всего многообразия статей, я увидел свой подход только в этом комментарии (но менее обобщённый).

UPD: Также у nor в блоге.

Проблема

Предположим, вы хотите найти последний элемент в отсортированном массиве, меньший x, и просто храните отрезок кандидатов на ответ [l, r]. Если вы напишете:

int l = 0, r = n - 1; // отрезок кандидатов
while (l < r) { // больше одного кандидата
    int mid = (l + r) / 2;
    if (a[mid] < x) {
        l = mid;
    } else {
        r = mid - 1;
    }
}
cout << l;

...вы наткнётесь на проблему: предположим, l = 3, r = 4, тогда mid = 3. Если a[mid] < x, вы попадёте в цикл (на следующей итерации отрезок будет снова [3, 4]).

Ладно, это можно исправить округлением mid вверх – mid = (l + r + 1) / 2. Но есть ещё одна проблема: что если в массиве нет таких элементов? Нам потребуется дополнительная проверка на этот случай. В результате, мы получили довольно уродливый код, который тяжело обобщить.

Мой подход

Давайте обобщим нашу задачу. У нас есть утверждение о целом числе n, которое истинно для целых чисел меньше некоторой границы, но потом всегда остаётся ложным как толькоn превысит её. Мы хотим найти последнееn, для которого утверждение истинно.

Прежде всего, мы будем использовать полуинтервалы вместо отрезков (одна граница включительная, другая – нет). Полуинтервалы вообще очень полезны, элегантны и общеприняты в программировании, я рекомендую использовать их по возможности. В этом случае мы выбираем некое маленькое l, для которого мы знаем, что утверждение истинно и некое большое r, для которого мы знаем, что оно ложно. Тогда диапазоном кандидатов будет [l, r).

Моим код для этой задачи было бы:

int l = -1, r = n; // полуинтервал [l, r) кандидатов
while (r - l > 1) { // больше одного кандидата
    int mid = (l + r) / 2;
    if (a[mid] < x) {
        l = mid; // теперь это наибольшее, для которого мы знаем, что это истинно
    } else {
        r = mid; // теперь это наименьшее, для которого мы знаем, что это ложно
    }
}
cout << l; // в конце остаётся диапозон [l, l + 1)

Бинпоиск будет проверять только числа строго между изначальными l и r. Это значит, что он никогда не проверит условие для l и r, он поверит вам, что для l это истинно, а для rложно. Здесь мы считаем -1 корректным ответом, который соответствует случаю, когда ни один элемент массива не меньше x.

Заметьте, что у меня в коде нет "+ 1" или "- 1", и мне не нужны дополнительные проверки, и попадание в цикл невозможно (т. к. mid строго между текущими l и r).

Обратная задача

Единственная вариация, которую вам надо держать в голове – это что в половине случаев вам нужно будет найти не последнее, а первое число, для которого нечто выполняется. В этом случае утверждение должно быть всегда ложно для меньших чисел и всегда истинно, начиная с некоторого числа.

Мы будем делать примерно то же самое, но теперьr будет включительной границей, а l – не включительной. Другими словами, l теперь – некоторое число, для которого мы знаем, что утверждение ложно, а r – некоторое, для которого мы знаем, что оно истинно. Предположим, я хочу найти первое число n, для которогоn * (n + 1) >= x (x положительно):

int l = 0, r = x; // полуинтервал (l, r] кандидатов
while (r - l > 1) { // больше 1 кандидата
    int mid = (l + r) / 2;
    if (mid * (mid + 1) >= x) {
        r = mid; // теперь это наименьшее, для которого мы знаем, что это истинно
    } else {
        l = mid; // теперь это наибольшее, для которого мы знаем, что это ложно
    }
}
cout << r; // в конце остаётся диапозон (r - 1, r]

Только будьте осторожны с выбором r, слишком большое может привести к переполнению.

Пример

1201C - Максимальная медиана Вам дан массив a нечётной длины n и за один шаг можно увеличить один элемент на 1. Какую наибольшую медиану массива можно достичь за k шагов?

Рассмотрим утверждение о числе x: мы можем сделать медиану не меньше x за не более, чемk шагов. Конечно, это всегда истинно до некоторого числа, а затем всегда ложно, поэтому можно использовать бинпоиск. Так как нам нужно последнее число, для которого это истинно, мы будем использовать обычный полуинтервал [l, r).

Чтобы проверить данное x, мы можем воспользоваться свойством медианы. Медиана не меньше x тогда и только тогда, когда не менее половины всех элементов не меньше x. Конечно, оптимальным способом добиться этого будет взять наибольшую половину элементов и сделать их не меньше x.

Конечно, можно достичь медиану не менее 1 при данных ограничениях, так что l будет равно1. Но даже если есть только один элемент, равный 1e9 и k тоже равно1e9, мы всё равно не сможем добиться медианы 2e9 + 1, так что r будет равно 2e9 + 1. Реализация:

#define int int64_t

int n, k;
cin >> n >> k;
vector<int> a(n);
for (int i = 0; i < n; i++) {
    cin >> a[i];
}
sort(a.begin(), a.end());
int l = 1, r = 2e9 + 1; // полуинтервал [l, r) кандидатов
while (r - l > 1) {
    int mid = (l + r) / 2;
    int cnt = 0; // необходимое количество шагов
    for (int i = n / 2; i < n; i++) { // идём по наибольшей половине
        if (a[i] < mid) {
            cnt += mid - a[i];
        }
    }
    if (cnt <= k) {
        l = mid;
    } else {
        r = mid;
    }
}
cout << l << endl;

Заключение

Надеюсь, я разъяснил кое-что, и часть из вас переключится на эту имплементацию. Для уточнения, иногда другие имплементации могут оказаться более подходящими, например, в интерактивных задачах – когда нам нужно мыслить в терминах интервала поиска, а не в терминах первого/последнего числа, для которого что-то выполняется.

Я напоминаю, что даю частные уроки по спортивному программированию, цена – $30/ч. Пишите мне в Telegram, Discord: rembocoder#3782, или в личные сообщения на CF.

  • Проголосовать: нравится
  • +40
  • Проголосовать: не нравится

»
2 года назад, # |
  Проголосовать: нравится +10 Проголосовать: не нравится

Thanks for the blog!

Re: example task, you forgot to paste the ever important #define int int64_t from your old solution. :)

And on the topic of binsearch, there is also the binary jumps perspective where you (in a way...) don't explicitly remember neither right nor mid. I don't tend to use it, but perhaps it may appeal to somebody more than other methods.

For the example task, the relevant part could go:

code

Submission: 158749229

»
2 года назад, # |
  Проголосовать: нравится +31 Проголосовать: не нравится

In my opinion the "Reverse problem" part is excessive. It basically replicates the previous part approach in terms of code. The very same code finds both: l is the last element for which the condition takes one value, and r is the first element for which the condition takes another value.

  • »
    »
    2 года назад, # ^ |
    Rev. 2   Проголосовать: нравится 0 Проголосовать: не нравится

    You don't even have to confuse yourself with different types of intervals. This approach makes so much more sense.

»
2 года назад, # |
  Проголосовать: нравится +54 Проголосовать: не нравится

Thanks for the blog! I used the same implementation before C++20, and I had a completely different interpretation, so it's nice to learn a new way of thinking about it. What I had as my invariant was the following (and it seems a bit more symmetric than in this blog):

Suppose we have a predicate $$$p$$$ that returns true for some prefix of $$$[L, R]$$$, and false for the remaining suffix. Then at any point in the algorithm:

  1. $$$l$$$ is the rightmost element for which you can prove that $$$f(l)$$$ is true.
  2. $$$r$$$ is the leftmost element for which you can prove that $$$f(r)$$$ is false.
  3. The remaining interval to search in is $$$(l, r)$$$.

$$$r - l > 1$$$ corresponds to having at least one unexplored element, and when the algorithm ends, $$$l, r$$$ are the positions of the rightmost true and the leftmost false. Intuitively, it is trying to look at positions like l]...[r, and trying to bring the ] and [ together, to find the "breaking point" (more details in my blog).

  • »
    »
    2 года назад, # ^ |
      Проголосовать: нравится 0 Проголосовать: не нравится

    I was going through previous posts on binary search and missed that in the middle of your post you introduce the same implementation, as me. Sorry, I must be more attentive.

    As for your comment – yes, I also state this in the comments of my code. But with your and SirShokoladina's comments I now see that such formulation of moving the borders together can be even more obvious for a beginner.

»
2 года назад, # |
  Проголосовать: нравится -11 Проголосовать: не нравится

tl;dr: half-open intervals are superior.

»
2 года назад, # |
  Проголосовать: нравится -13 Проголосовать: не нравится

we can do the same with doubles but instead of while(r-l>1)

we must use while(r-l>eps)

  • »
    »
    2 года назад, # ^ |
      Проголосовать: нравится +55 Проголосовать: не нравится

    This is known to be bad due to floating point precision issues. For floating point numbers, use a fixed number of iterations instead.

»
2 года назад, # |
  Проголосовать: нравится +29 Проголосовать: не нравится

In the example task your code only works because of define int int64_t. Indeed, if the answer is at least 1e9, then during the second iteration you do mid = (l + r) / 2, which, being done in int32_ts, would result in 3e9 being overflown, and then mid becoming negative.

To avoid this, one can, indeed, work with 64-bit type, or use mid = l + (r - l) / 2, which is arithmetically the same, but does not overflow. Since C++20 there is also std::midpoint, which does the same.

  • »
    »
    2 года назад, # ^ |
    Rev. 3   Проголосовать: нравится 0 Проголосовать: не нравится

    $$$l + (r - l) / 2$$$ overflows for $$$l = \mathtt{INT\_MIN}$$$ and $$$r = \mathtt{INT\_MAX}$$$. A way that works for $$$l \le r$$$ is (which you might have needed before C++20 came in with std::midpoint, and which might even be faster):

    std::int32_t m = l + static_cast<std::int32_t>((static_cast<std::uint32_t>(r) - static_cast<std::uint32_t>(l))/ 2);
    

    or the more readable, but perhaps less "industry" standard:

    int32_t m = l + int32_t((uint32_t(r) - uint32_t(l)) / 2);
    

    Why this works: After replacing $$$l, r$$$ by their values modulo $$$2^{32}$$$ by doing a static cast to std::uint32_t, it is guaranteed that the difference $$$d$$$ between them is equal to $$$r - l$$$ modulo $$$2^{32}$$$. Since $$$l \le r$$$, $$$r - l$$$ has to be non-negative, and the difference between them is less than $$$2^{32}$$$, the required difference is indeed equal to $$$d$$$. Now $$$d / 2$$$ fits in a std::int32_t, so we can safely cast back to std::int32_t, and carry on with our computation of $$$m$$$.

    Note that the static_cast to unsigned is implementation defined before C++20, but for pretty much every implementation of GCC, it is twos-complement, so that works out practically.

    The same thing can be done with 32 replaced by 64 everywhere, making the code independent of 128-bit types.

    Also since we're already on the topic of overflow, it is a bit unfortunate that the binary search can be only done where the range of possible values (that is, where we can evaluate the predicate, rather than the return value) is $$$(\mathtt{INT\_MIN}, \mathtt{INT\_MAX})$$$, so we miss out the integers $$$\mathtt{INT\_MIN}, \mathtt{INT\_MAX}$$$. I am not completely sure that this is a bad thing either, though.

    Edit: similarly, the condition $$$r - l > 1$$$ suffers from overflows. We can fix it by simply checking if $$$l \ne \mathtt{INT\_MAX}$$$ and $$$l + 1 < r$$$, or just uint32_t(r) - uint32_t(l) > 1 (the starting $$$l, r$$$ should satisfy $$$r \ge l$$$ for the second idea to work).

»
2 года назад, # |
  Проголосовать: нравится +83 Проголосовать: не нравится

Since we are talking about binary search, let me remind everyone what the best implementation is.

int ans = 0;
for (int k = 1 << 29; k != 0; k /= 2) {
  if (can(ans + k)) {
    ans += k;
  }
}
»
2 года назад, # |
  Проголосовать: нравится 0 Проголосовать: не нравится

I'll notice that it's only possible when there's sentinel value that can never be returned. If you are implementing analogue of lower_bound you need to get prev(begin) or you have to switch to other implementations (adding one basically)