From f3ef256e651279a3f06baa0c4b9a98d3fea25d4a Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Thu, 3 Apr 2025 23:25:08 -0300 Subject: [PATCH] sdk: wot xor filter has a proper .Contains() function. --- sdk/wot.go | 62 ++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/sdk/wot.go b/sdk/wot.go index af5db2f..96f128d 100644 --- a/sdk/wot.go +++ b/sdk/wot.go @@ -8,37 +8,73 @@ import ( "github.com/FastFilter/xorfilter" "golang.org/x/sync/errgroup" + "sync" ) +func PubKeyToShid(pubkey string) uint64 { + shid, _ := strconv.ParseUint(pubkey[32:48], 16, 64) + return shid +} + func (sys *System) GetWoT(ctx context.Context, pubkey string) (map[uint64]struct{}, error) { g, ctx := errgroup.WithContext(ctx) + g.SetLimit(30) - res := make(chan uint64) + res := make(chan uint64, 100) // Add buffer to prevent blocking + result := make(map[uint64]struct{}) + var resultMu sync.Mutex // Add mutex to protect map access + + // Start consumer goroutine + done := make(chan struct{}) + go func() { + defer close(done) + for shid := range res { + resultMu.Lock() + result[shid] = struct{}{} + resultMu.Unlock() + } + }() + + // Process follow lists for _, f := range sys.FetchFollowList(ctx, pubkey).Items { + f := f // Capture loop variable g.Go(func() error { for _, f2 := range sys.FetchFollowList(ctx, f.Pubkey).Items { - shid, _ := strconv.ParseUint(f2.Pubkey[32:48], 16, 64) - res <- shid + select { + case res <- PubKeyToShid(f2.Pubkey): + case <-ctx.Done(): + return ctx.Err() + } } return nil }) } - result := make(map[uint64]struct{}) - go func() { - for shid := range res { - result[shid] = struct{}{} - } - }() + err := g.Wait() + close(res) // Close channel after all goroutines are done + <-done // Wait for consumer to finish - return result, g.Wait() + return result, err } -func (sys *System) GetWoTFilter(ctx context.Context, pubkey string) (*xorfilter.Xor8, error) { +func (sys *System) GetWoTFilter(ctx context.Context, pubkey string) (WotXorFilter, error) { m, err := sys.GetWoT(ctx, pubkey) if err != nil { - return nil, err + return WotXorFilter{}, err } - return xorfilter.Populate(slices.Collect(maps.Keys(m))) + xf, err := xorfilter.Populate(slices.Collect(maps.Keys(m))) + if err != nil { + return WotXorFilter{}, err + } + + return WotXorFilter{*xf}, nil +} + +type WotXorFilter struct { + xorfilter.Xor8 +} + +func (wxf WotXorFilter) Contains(pubkey string) bool { + return wxf.Xor8.Contains(PubKeyToShid(pubkey)) }