@
Knuth 这个算法很优秀,但它是 O(n)级别的,如果 n 很大但是 m 很小会很慢。我以前搞出一个“模拟打乱算法”是 O(m)级别的,原理是模仿打乱算法,但是不真的做一个 O(n)的数组出来,而是用 hash 表代替“已经打乱过的区域”,只需要 O(m)的空间即可。在 m 很小但是 n 很大的时候会比较优,下面是 Go 的实现:
```golang
// SampleFilter 随机采样[0-N)中的 m 个,排除 f 返回 false 的元素.
func SampleFilter(n uint32, m uint32, f func(uint32) bool) []uint32 {
if n == 0 || m == 0 {
return nil
}
// 如果采样结果比长度还多,就直接顺序返回全部值
if m >= n {
indexes := make([]uint32, 0, n)
for i := uint32(0); i < n; i++ {
if f(i) {
indexes = append(indexes, i)
}
}
return indexes
}
// 否则,做一个虚拟索引映射表(表里不存在的就是原索引),走洗牌算法
indexMap := make(map[uint32]uint32, m)
indexes := make([]uint32, 0, m)
for i := uint32(0); i < n; i++ {
if uint32(len(indexes)) >= m {
return indexes
}
// 先获取一个随机索引的映射
ri := RangeRand(i, n-1)
mappedIdx, ok := indexMap[ri]
if !ok {
mappedIdx = ri
}
// 如果满足要求,就加入到结果中
if f(mappedIdx) {
indexes = append(indexes, mappedIdx)
}
// 在随机位置填入当前位置映射后的索引
indexMap[ri], ok = indexMap[i]
if !ok {
indexMap[ri] = i
}
}
return indexes
}
```