fix quickselect.

This commit is contained in:
fiatjaf
2025-08-04 17:29:12 -03:00
parent 2750ae3751
commit 1cd48343d6
3 changed files with 120 additions and 45 deletions

View File

@@ -93,39 +93,76 @@ func (it *iterator) next() {
type iterators []iterator
// quickselect reorders the slice just enough to make the top k elements be arranged at the end
// i.e. [1, 700, 25, 312, 44, 28] with k=3 becomes something like [28, 25, 1, 44, 312, 700]
// i.e. [1, 700, 25, 312, 44, 28] with k=3 becomes something like [700, 312, 44, 1, 25, 28]
// in this case it's hardcoded to use the 'last' field of the iterator
func (its iterators) quickselect(left int, right int, k int) {
if right == left {
return
// copied from https://github.com/chrislee87/go-quickselect
// this is modified to also return the highest 'last' (because it's not guaranteed it will be the first item)
func (its iterators) quickselect(k int) uint32 {
if len(its) == 0 || k >= len(its) {
return 0
}
left, right := 0, len(its)-1
for {
// insertion sort for small ranges
if right-left <= 20 {
for i := left + 1; i <= right; i++ {
for j := i; j > 0 && its[j].last > its[j-1].last; j-- {
its[j], its[j-1] = its[j-1], its[j]
}
}
return its[0].last
}
// median-of-three to choose pivot
pivotIndex := left + (right-left)/2
if its[right].last > its[left].last {
its[right], its[left] = its[left], its[right]
}
if its[pivotIndex].last > its[left].last {
its[pivotIndex], its[left] = its[left], its[pivotIndex]
}
if its[right].last > its[pivotIndex].last {
its[right], its[pivotIndex] = its[pivotIndex], its[right]
}
// partition
pivot := its[(right+left)/2].last
l := left
r := right
its[left], its[pivotIndex] = its[pivotIndex], its[left]
ll := left + 1
rr := right
for ll <= rr {
for ll <= right && its[ll].last > its[left].last {
ll++
}
for rr >= left && its[left].last > its[rr].last {
rr--
}
if ll <= rr {
its[ll], its[rr] = its[rr], its[ll]
ll++
rr--
}
}
its[left], its[rr] = its[rr], its[left] // swap into right place
pivotIndex = rr
for l <= r {
for its[l].last < pivot {
l++
if k == pivotIndex {
// now that stuff is selected we get the highest "last"
highest := its[0].last
for i := 1; i < k; i++ {
if its[i].last > highest {
highest = its[i].last
}
for its[r].last > pivot {
r--
}
if l >= r {
break
return highest
}
its[l].last, its[r].last = its[r].last, its[l].last
r--
l++
}
mid := r
// ~
if k > mid {
its.quickselect(mid+1, right, k)
if k < pivotIndex {
right = pivotIndex - 1
} else {
its.quickselect(left, mid, k)
left = pivotIndex + 1
}
}
}

View File

@@ -78,11 +78,9 @@ func (b *LMDBBackend) query(txn *lmdb.Txn, filter nostr.Filter, limit int, yield
// after pulling from all iterators once we now find out what iterators are
// the ones we should keep pulling from next (i.e. which one's last emitted timestamp is the highest)
iterators.quickselect(min(numberOfIteratorsToPullOnEachRound, len(iterators)), 0, len(iterators))
threshold := iterators.quickselect(min(numberOfIteratorsToPullOnEachRound, len(iterators)))
// we now know what is our threshold
threshold := iterators[len(iterators)-1].last
// so we can emit all the events higher than it
// so we can emit all the events higher than the threshold
for _, it := range iterators {
for t, ts := range it.timestamps {
if ts >= threshold {

View File

@@ -7,6 +7,7 @@ import (
)
func TestQuickselect(t *testing.T) {
{
its := iterators{
{last: 781},
{last: 900},
@@ -19,9 +20,48 @@ func TestQuickselect(t *testing.T) {
{last: 444},
}
its.quickselect(3, 0, len(its))
require.ElementsMatch(t, its[len(its)-3:], iterators{{last: 900}, {last: 781}, {last: 781}})
its.quickselect(3)
require.ElementsMatch(t,
[]uint32{its[0].last, its[1].last, its[2].last},
[]uint32{900, 781, 781},
)
}
its.quickselect(4, 0, len(its))
require.ElementsMatch(t, its[len(its)-4:], iterators{{last: 562}, {last: 900}, {last: 781}, {last: 781}})
{
its := iterators{
{last: 781},
{last: 781},
{last: 900},
{last: 1},
{last: 87},
{last: 315},
{last: 789},
{last: 500},
{last: 812},
{last: 306},
{last: 612},
{last: 444},
{last: 59},
{last: 441},
{last: 901},
{last: 901},
{last: 2},
{last: 81},
{last: 325},
{last: 781},
{last: 562},
{last: 81},
{last: 326},
{last: 662},
{last: 444},
{last: 81},
{last: 444},
}
its.quickselect(6)
require.ElementsMatch(t,
[]uint32{its[0].last, its[1].last, its[2].last, its[3].last, its[4].last, its[5].last},
[]uint32{901, 900, 901, 781, 812, 789},
)
}
}