题意:给两个升序数组 AB,长度分别是 mn,求合并后的中位数。
本题中位数定义为下标 (N-1)/2 的元素(N = m+n),也就是“下中位数”。

例如:

  • [1,2][3,4] 合并后是 [1,2,3,4],答案是 2

1. 转化

目标是求第 k 小元素:

  • k = (m+n+1)/2(1-index)

所以问题变成:在两个有序数组中找第 k 小

2. 做法(每次丢掉一半)

设当前还在考虑:

  • A[ia..]
  • B[ib..]
  • 目标第 k

每轮做:

  1. half = k/2
  2. 比较 A[ia + half - 1]B[ib + half - 1](越界就取到数组末尾)
  3. 较小的一侧,不可能包含第 k 小之前更多有用元素,可整体丢弃
  4. 更新 k 和对应起点

边界:

  • 某个数组耗尽,答案直接在另一个数组里
  • k==1 时答案是两侧当前最小值

复杂度:

  • 时间:O(log(m+n))
  • 空间:O(1)(不计输入数组)

3. 代码(C++17)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <bits/stdc++.h>
using namespace std;

int kthSmallest(const vector<int>& a, const vector<int>& b, int k) {
int ia = 0, ib = 0;
int na = (int)a.size(), nb = (int)b.size();

while (true) {
if (ia >= na) return b[ib + k - 1];
if (ib >= nb) return a[ia + k - 1];
if (k == 1) return min(a[ia], b[ib]);

int half = k >> 1;
int newIa = min(ia + half, na) - 1;
int newIb = min(ib + half, nb) - 1;

if (a[newIa] <= b[newIb]) {
k -= (newIa - ia + 1);
ia = newIa + 1;
} else {
k -= (newIb - ib + 1);
ib = newIb + 1;
}
}
}

int main() {
int m, n;
while (scanf("%d%d", &m, &n) == 2) {
vector<int> a(m), b(n);
for (int i = 0; i < m; ++i) scanf("%d", &a[i]);
for (int i = 0; i < n; ++i) scanf("%d", &b[i]);

int k = (m + n + 1) / 2; // 下中位数(1-index)
int ans = kthSmallest(a, b, k);
printf("%d\n", ans);
}
return 0;
}

4. 易错点

  1. 本题是“下中位数”,不是偶数长度时取均值。
  2. 多组输入,记得 while (scanf(...) == 2)
  3. half 取值后下标要处理越界,不能直接 ia + half - 1 不判断。

5. 补一版:二分划分(partition)写法

这个写法同样是 O(log(m+n)),核心是:

  • 设从 Ai 个、从 Bj=k-i
  • 调整 i,让左半部分都不大于右半部分
  • 满足条件时,答案是 max(A[i-1], B[j-1])

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <bits/stdc++.h>
using namespace std;

int kthByPartition(const vector<int>& A, const vector<int>& B, int k) {
const vector<int> *pa = &A, *pb = &B;
int n = (int)A.size(), m = (int)B.size();

if (n > m) {
swap(n, m);
swap(pa, pb);
}

int L = max(0, k - m), R = min(k, n);
while (L <= R) {
int i = (L + R) >> 1;
int j = k - i;

int aL = (i == 0 ? INT_MIN : (*pa)[i - 1]);
int aR = (i == n ? INT_MAX : (*pa)[i]);
int bL = (j == 0 ? INT_MIN : (*pb)[j - 1]);
int bR = (j == m ? INT_MAX : (*pb)[j]);

if (aL <= bR && bL <= aR) {
return max(aL, bL);
}
if (aL > bR) {
R = i - 1;
} else {
L = i + 1;
}
}

return -1;
}

int main() {
int m, n;
while (scanf("%d%d", &m, &n) == 2) {
vector<int> a(m), b(n);
for (int i = 0; i < m; ++i) scanf("%d", &a[i]);
for (int i = 0; i < n; ++i) scanf("%d", &b[i]);

int k = (m + n + 1) / 2; // 下中位数(1-index)
printf("%d\n", kthByPartition(a, b, k));
}
return 0;
}