eventstore: SortedMerge() takes a limit and is simpler (should be faster) for small limits.
This commit is contained in:
@@ -2,15 +2,46 @@ package eventstore
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"iter"
|
"iter"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"fiatjaf.com/nostr"
|
"fiatjaf.com/nostr"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SortedMerge(it1, it2 iter.Seq[nostr.Event]) iter.Seq[nostr.Event] {
|
// SortedMerge combines two iterators and returns the top limit results aggregated from both.
|
||||||
|
// limit is implied to be also the maximum number of items each iterator will return.
|
||||||
|
func SortedMerge(it1, it2 iter.Seq[nostr.Event], limit int) iter.Seq[nostr.Event] {
|
||||||
|
if limit < 60 {
|
||||||
|
return func(yield func(nostr.Event) bool) {
|
||||||
|
acc := make([]nostr.Event, 0, limit*2)
|
||||||
|
for evt := range it1 {
|
||||||
|
acc = append(acc, evt)
|
||||||
|
}
|
||||||
|
for evt := range it2 {
|
||||||
|
acc = append(acc, evt)
|
||||||
|
}
|
||||||
|
slices.SortFunc(acc, nostr.CompareEventReverse)
|
||||||
|
for i := range min(limit, len(acc)) {
|
||||||
|
if !yield(acc[i]) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
next1, done1 := iter.Pull(it1)
|
next1, done1 := iter.Pull(it1)
|
||||||
next2, done2 := iter.Pull(it2)
|
next2, done2 := iter.Pull(it2)
|
||||||
|
|
||||||
return func(yield func(nostr.Event) bool) {
|
return func(yieldInner func(nostr.Event) bool) {
|
||||||
|
count := 0
|
||||||
|
yield := func(evt nostr.Event) bool {
|
||||||
|
shouldContinue := yieldInner(evt)
|
||||||
|
count++
|
||||||
|
if count >= limit {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return shouldContinue
|
||||||
|
}
|
||||||
|
|
||||||
defer done1()
|
defer done1()
|
||||||
defer done2()
|
defer done2()
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"fiatjaf.com/nostr"
|
"fiatjaf.com/nostr"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func FuzzSortedMerge(f *testing.F) {
|
func FuzzSortedMerge(f *testing.F) {
|
||||||
@@ -19,7 +20,7 @@ func FuzzSortedMerge(f *testing.F) {
|
|||||||
merged := SortedMerge(
|
merged := SortedMerge(
|
||||||
func(yield func(nostr.Event) bool) {
|
func(yield func(nostr.Event) bool) {
|
||||||
for range len1 {
|
for range len1 {
|
||||||
if !yield(nostr.Event{CreatedAt: nostr.Timestamp(start1)}) {
|
if !yield(nostr.Event{ID: nostr.ID(nostr.Generate()), CreatedAt: nostr.Timestamp(start1)}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
start1 -= uint(diff1)
|
start1 -= uint(diff1)
|
||||||
@@ -27,19 +28,18 @@ func FuzzSortedMerge(f *testing.F) {
|
|||||||
},
|
},
|
||||||
func(yield func(nostr.Event) bool) {
|
func(yield func(nostr.Event) bool) {
|
||||||
for range len2 {
|
for range len2 {
|
||||||
if !yield(nostr.Event{CreatedAt: nostr.Timestamp(start2)}) {
|
if !yield(nostr.Event{ID: nostr.ID(nostr.Generate()), CreatedAt: nostr.Timestamp(start2)}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
start2 -= uint(diff2)
|
start2 -= uint(diff2)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
int(max(len1, len2)),
|
||||||
)
|
)
|
||||||
result := slices.Collect(merged)
|
result := slices.Collect(merged)
|
||||||
|
|
||||||
// assert length
|
// assert length
|
||||||
if len(result) != int(len1+len2) {
|
require.Equal(t, int(max(len1, len2)), len(result), "got a different number of results than expected")
|
||||||
t.Fatalf("expected %d events, got %d", len1+len2, len(result))
|
|
||||||
}
|
|
||||||
|
|
||||||
// assert sorted descending
|
// assert sorted descending
|
||||||
slices.IsSortedFunc(result, func(a, b nostr.Event) int { return -1 * cmp.Compare(a.CreatedAt, b.CreatedAt) })
|
slices.IsSortedFunc(result, func(a, b nostr.Event) int { return -1 * cmp.Compare(a.CreatedAt, b.CreatedAt) })
|
||||||
|
|||||||
7
eventstore/testdata/fuzz/FuzzSortedMerge/0076b595fbac65cc
vendored
Normal file
7
eventstore/testdata/fuzz/FuzzSortedMerge/0076b595fbac65cc
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
go test fuzz v1
|
||||||
|
uint(52)
|
||||||
|
uint(16)
|
||||||
|
uint(56)
|
||||||
|
uint(7)
|
||||||
|
byte('#')
|
||||||
|
byte('\x00')
|
||||||
Reference in New Issue
Block a user