diff --git a/envelopes.go b/envelopes.go
index 3081537..5fcb78b 100644
--- a/envelopes.go
+++ b/envelopes.go
@@ -2,6 +2,7 @@ package nostr
 
 import (
 	"bytes"
+	"encoding/hex"
 	"fmt"
 	"strconv"
 
@@ -141,7 +142,8 @@ func (v ReqEnvelope) MarshalJSON() ([]byte, error) {
 type CountEnvelope struct {
 	SubscriptionID string
 	Filters
-	Count *int64
+	Count       *int64
+	HyperLogLog []byte
 }
 
 func (_ CountEnvelope) Label() string { return "COUNT" }
@@ -160,9 +162,14 @@ func (v *CountEnvelope) UnmarshalJSON(data []byte) error {
 
 	var countResult struct {
 		Count *int64 `json:"count"`
+		HLL   string `json:"hll"`
 	}
 	if err := json.Unmarshal([]byte(arr[2].Raw), &countResult); err == nil && countResult.Count != nil {
 		v.Count = countResult.Count
+		// keep the registers only when the whole hex string decodes
+		if hll, err := hex.DecodeString(countResult.HLL); err == nil {
+			v.HyperLogLog = hll
+		}
 		return nil
 	}
 
@@ -188,6 +195,15 @@ func (v CountEnvelope) MarshalJSON() ([]byte, error) {
 	if v.Count != nil {
 		w.RawString(`,{"count":`)
 		w.RawString(strconv.FormatInt(*v.Count, 10))
+		if len(v.HyperLogLog) > 0 {
+			w.RawString(`,"hll":"`)
+			// size the buffer from the data: a fixed 512 bytes makes hex.Encode panic for more
+			// than 256 registers and leaves NUL bytes in the JSON string for fewer
+			hllHex := make([]byte, hex.EncodedLen(len(v.HyperLogLog)))
+			hex.Encode(hllHex, v.HyperLogLog)
+			w.Buffer.AppendBytes(hllHex)
+			w.RawString(`"`)
+		}
 		w.RawString(`}`)
 	} else {
 		for _, filter := range v.Filters {
diff --git a/nip45/hll_event.go b/nip45/hll_event.go
new file mode 100644
index 0000000..dc4d01f
--- /dev/null
+++ b/nip45/hll_event.go
@@ -0,0 +1,46 @@
+package nip45
+
+import (
+	"iter"
+	"strconv"
+
+	"github.com/nbd-wtf/go-nostr"
+)
+
+// HyperLogLogEventPubkeyOffsetsAndReferencesForEvent yields, for every reference in evt that is
+// eligible for hyperloglog counting, the referenced id together with the pubkey offset to use for it.
+//
+// Kind 3 events yield each valid "p" tag; kind 7 events yield only their last "e" tag.
+func HyperLogLogEventPubkeyOffsetsAndReferencesForEvent(evt *nostr.Event) iter.Seq2[string, int] {
+	return func(yield func(string, int) bool) {
+		switch evt.Kind {
+		case 3:
+			//
+			// follower counts
+			for _, tag := range evt.Tags {
+				if len(tag) >= 2 && tag[0] == "p" && nostr.IsValid32ByteHex(tag[1]) {
+					// 32th nibble of each "p" tag
+					p, _ := strconv.ParseInt(tag[1][32:33], 16, 64)
+					if !yield(tag[1], int(p+8)) {
+						return
+					}
+				}
+			}
+		case 7:
+			//
+			// reaction counts:
+			// (only the last "e" tag counts)
+			lastE := evt.Tags.GetLast([]string{"e", ""})
+			if
lastE != nil {
+				v := (*lastE)[1]
+				if nostr.IsValid32ByteHex(v) {
+					// 32th nibble of "e" tag
+					p, _ := strconv.ParseInt(v[32:33], 16, 64)
+					if !yield(v, int(p+8)) {
+						return
+					}
+				}
+			}
+		}
+	}
+}
diff --git a/nip45/hll_filter.go b/nip45/hll_filter.go
new file mode 100644
index 0000000..c075b38
--- /dev/null
+++ b/nip45/hll_filter.go
@@ -0,0 +1,51 @@
+package nip45
+
+import (
+	"strconv"
+
+	"github.com/nbd-wtf/go-nostr"
+)
+
+// HyperLogLogEventPubkeyOffsetForFilter returns the deterministic pubkey offset that will be used
+// when computing hyperloglogs in the context of a specific filter.
+//
+// It returns -1 when the filter is not eligible for hyperloglog calculation.
+func HyperLogLogEventPubkeyOffsetForFilter(filter nostr.Filter) int {
+	if filter.IDs != nil || filter.Since != nil || filter.Until != nil || filter.Authors != nil ||
+		len(filter.Kinds) != 1 || filter.Search != "" || len(filter.Tags) != 1 {
+		// obvious cases in which we won't bother to do hyperloglog stuff
+		return -1
+	}
+
+	// only serve the cases explicitly defined by the NIP:
+	if pTags, ok := filter.Tags["p"]; ok {
+		//
+		// follower counts:
+		// (validate the tag value before slicing: a shorter string would make [32:33] panic)
+		if filter.Kinds[0] == 3 && len(pTags) == 1 && nostr.IsValid32ByteHex(pTags[0]) {
+			// 32th nibble of "p" tag
+			p, err := strconv.ParseInt(pTags[0][32:33], 16, 64)
+			if err != nil {
+				return -1
+			}
+			return int(p + 8)
+		}
+	} else if eTags, ok := filter.Tags["e"]; ok {
+		// (validate the tag value before slicing: a shorter string would make [32:33] panic)
+		if len(eTags) == 1 && nostr.IsValid32ByteHex(eTags[0]) {
+			//
+			// reaction counts:
+			if filter.Kinds[0] == 7 {
+				// 32th nibble of "e" tag
+				p, err := strconv.ParseInt(eTags[0][32:33], 16, 64)
+				if err != nil {
+					return -1
+				}
+				return int(p + 8)
+			}
+		}
+	}
+
+	// everything else is false at least for now
+	return -1
+}
diff --git a/nip45/hyperloglog/helpers.go b/nip45/hyperloglog/helpers.go
new file mode 100644
index 0000000..88aba53
--- /dev/null
+++ b/nip45/hyperloglog/helpers.go
@@ -0,0 +1,30 @@
+package hyperloglog
+
+import (
+	"math"
+)
+
+const two32 = 1 << 32
+
+func linearCounting(m uint32, v uint32) float64 {
+	fm := float64(m)
+	return
fm * math.Log(fm/float64(v))
+}
+
+func clz56(x uint64) uint8 {
+	var c uint8
+	for m := uint64(1 << 55); m&x == 0 && m != 0; m >>= 1 {
+		c++
+	}
+	return c
+}
+
+func countZeros(s []uint8) uint32 {
+	var c uint32
+	for _, v := range s {
+		if v == 0 {
+			c++
+		}
+	}
+	return c
+}
diff --git a/nip45/hyperloglog/hll.go b/nip45/hyperloglog/hll.go
new file mode 100644
index 0000000..e546b67
--- /dev/null
+++ b/nip45/hyperloglog/hll.go
@@ -0,0 +1,128 @@
+package hyperloglog
+
+import (
+	"encoding/binary"
+	"encoding/hex"
+	"fmt"
+)
+
+// Everything is hardcoded to use precision 8, i.e. 256 registers.
+type HyperLogLog struct {
+	offset    int
+	registers []uint8
+}
+
+// New returns an empty HyperLogLog whose 8-byte item key starts at byte `offset` of each added pubkey.
+// It panics if offset is outside the range [0, 24].
+func New(offset int) *HyperLogLog {
+	if offset < 0 || offset > 32-8 {
+		panic(fmt.Errorf("invalid offset %d", offset))
+	}
+
+	// precision is always 8
+	// the number of registers is always 256 (1<<8)
+	hll := &HyperLogLog{offset: offset}
+	hll.registers = make([]uint8, 256)
+	return hll
+}
+
+// NewWithRegisters is like New but starts from existing registers, which must have length 256.
+// The slice is used as is, not copied. It panics on an invalid offset or register count.
+func NewWithRegisters(registers []byte, offset int) *HyperLogLog {
+	if offset < 0 || offset > 32-8 {
+		panic(fmt.Errorf("invalid offset %d", offset))
+	}
+	if len(registers) != 256 {
+		panic(fmt.Errorf("invalid number of registers %d", len(registers)))
+	}
+	return &HyperLogLog{registers: registers, offset: offset}
+}
+
+func (hll *HyperLogLog) GetRegisters() []byte    { return hll.registers }
+func (hll *HyperLogLog) SetRegisters(enc []byte) { hll.registers = enc }
+
+// MergeRegisters folds a raw register array into hll, keeping the larger value of each register.
+// Entries of other beyond the registers hll holds are ignored instead of indexing out of range.
+func (hll *HyperLogLog) MergeRegisters(other []byte) {
+	for i, v := range other[:min(len(other), len(hll.registers))] {
+		hll.registers[i] = max(hll.registers[i], v)
+	}
+}
+
+// Clear resets every register to zero.
+func (hll *HyperLogLog) Clear() {
+	clear(hll.registers)
+}
+
+// Add takes a Nostr event pubkey which will be used as the item "key" (that combined with the offset)
+//
+// Pubkeys that are too short for the configured offset, or that are not valid hex, are ignored.
+func (hll *HyperLogLog) Add(pubkey string) {
+	start := hll.offset * 2
+	if len(pubkey) < start+8*2 {
+		return
+	}
+	x, err := hex.DecodeString(pubkey[start : start+8*2])
+	if err != nil {
+		return
+	}
+	j := x[0] // register address (first 8 bits, i.e.
first byte) + + w := binary.BigEndian.Uint64(x) // number that we will use + zeroBits := clz56(w) + 1 // count zeroes (skip the first byte, so only use 56 bits) + + if zeroBits > hll.registers[j] { + hll.registers[j] = zeroBits + } +} + +// AddBytes is like Add, but takes pubkey as bytes instead of as string +func (hll *HyperLogLog) AddBytes(pubkey []byte) { + x := pubkey[hll.offset : hll.offset+8] + j := x[0] // register address (first 8 bits, i.e. first byte) + + w := binary.BigEndian.Uint64(x) // number that we will use + zeroBits := clz56(w) + 1 // count zeroes (skip the first byte, so only use 56 bits) + + if zeroBits > hll.registers[j] { + hll.registers[j] = zeroBits + } +} + +func (hll *HyperLogLog) Merge(other *HyperLogLog) { + for i, v := range other.registers { + if v > hll.registers[i] { + hll.registers[i] = v + } + } +} + +func (hll *HyperLogLog) Count() uint64 { + v := countZeros(hll.registers) + + if v != 0 { + lc := linearCounting(256 /* nregisters */, v) + + if lc <= 220 /* threshold */ { + return uint64(lc) + } + } + + est := hll.calculateEstimate() + if est <= 256 /* nregisters */ *3 { + if v != 0 { + return uint64(linearCounting(256 /* nregisters */, v)) + } + } + + return uint64(est) +} + +func (hll HyperLogLog) calculateEstimate() float64 { + sum := 0.0 + for _, val := range hll.registers { + sum += 1.0 / float64(uint64(1)< %d)", c, count) + } +} + +func TestHyperLogLogMerge(t *testing.T) { + rand := rand.New(rand.NewPCG(2, 0)) + + for _, count := range []int{ + 2, 4, 6, 7, 12, 15, 22, 36, 44, 47, + 64, 77, 89, 95, 104, 116, 122, 144, + 150, 199, 300, 350, 400, 500, 600, + 777, 922, 1000, 1500, 2222, 9999, + 13600, 80000, 133333, 200000, + } { + hllA := New() + hllB := New() + + for range count / 2 { + b := make([]byte, 32) + for i := range b { + b[i] = uint8(rand.UintN(256)) + } + id := hex.EncodeToString(b) + hllA.Add(id) + } + for range count / 2 { + b := make([]byte, 32) + for i := range b { + b[i] = uint8(rand.UintN(256)) + } + id := 
hex.EncodeToString(b)
+			hllB.Add(id)
+		}
+
+		hll := New()
+		hll.Merge(hllA)
+		hll.Merge(hllB)
+
+		res100 := int(hll.Count() * 100)
+		require.Greater(t, res100, count*85, "result too low (actual %d < %d)", hll.Count(), count)
+		require.Less(t, res100, count*115, "result too high (actual %d > %d)", hll.Count(), count)
+	}
+}
+
+func TestHyperLogLogMergeComplex(t *testing.T) {
+	rand := rand.New(rand.NewPCG(4, 0))
+
+	for _, count := range []int{
+		3, 6, 9, 12, 15, 22, 36, 46, 57,
+		64, 77, 89, 95, 104, 116, 122, 144,
+		150, 199, 300, 350, 400, 500, 600,
+		777, 922, 1000, 1500, 2222, 9999,
+		13600, 80000, 133333, 200000,
+	} {
+		hllA := New(0)
+		hllB := New(0)
+		hllC := New(0)
+
+		for range count / 3 {
+			b := make([]byte, 32)
+			for i := range b {
+				b[i] = uint8(rand.UintN(256))
+			}
+			id := hex.EncodeToString(b)
+			hllA.Add(id)
+			hllC.Add(id)
+		}
+		for range count / 3 {
+			b := make([]byte, 32)
+			for i := range b {
+				b[i] = uint8(rand.UintN(256))
+			}
+			id := hex.EncodeToString(b)
+			hllB.Add(id)
+			hllC.Add(id)
+		}
+		for range count / 3 {
+			b := make([]byte, 32)
+			for i := range b {
+				b[i] = uint8(rand.UintN(256))
+			}
+			id := hex.EncodeToString(b)
+			hllC.Add(id)
+			hllA.Add(id)
+		}
+
+		hll := New(0)
+		hll.Merge(hllA)
+		hll.Merge(hllB)
+		hll.Merge(hllC)
+
+		res100 := int(hll.Count() * 100)
+		require.Greater(t, res100, count*85, "result too low (actual %d < %d)", hll.Count(), count)
+		require.Less(t, res100, count*115, "result too high (actual %d > %d)", hll.Count(), count)
+	}
+}
diff --git a/pool.go b/pool.go
index a2004f2..65e70c8 100644
--- a/pool.go
+++ b/pool.go
@@ -10,6 +10,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/nbd-wtf/go-nostr/nip45/hyperloglog"
 	"github.com/puzpuzpuz/xsync/v3"
 )
 
@@ -468,6 +469,46 @@ func (pool *SimplePool) subManyEose(
 	return events
 }
 
+// CountMany aggregates count results from multiple relays using HyperLogLog.
+//
+// Relays that fail, or that answer without a 256-register HyperLogLog, contribute nothing.
+func (pool *SimplePool) CountMany(
+	ctx context.Context,
+	urls []string,
+	filter Filter,
+	opts []SubscriptionOption,
+) int {
+	hll := hyperloglog.New(0) // offset is irrelevant here, so we just pass 0
+
+	var (
+		wg sync.WaitGroup
+		mu sync.Mutex // serializes merges into hll, which run on one goroutine per relay
+	)
+	wg.Add(len(urls))
+	for _, url := range urls {
+		go func(nm string) {
+			defer wg.Done()
+			relay, err := pool.EnsureRelay(nm)
+			if err != nil {
+				return
+			}
+			ce, err := relay.countInternal(ctx, Filters{filter}, opts...)
+			if err != nil {
+				return
+			}
+			if len(ce.HyperLogLog) != 256 {
+				return
+			}
+			mu.Lock()
+			hll.MergeRegisters(ce.HyperLogLog)
+			mu.Unlock()
+		}(NormalizeURL(url))
+	}
+
+	wg.Wait()
+	return int(hll.Count())
+}
+
 // QuerySingle returns the first event returned by the first relay, cancels everything else.
 func (pool *SimplePool) QuerySingle(ctx context.Context, urls []string, filter Filter) *RelayEvent {
 	ctx, cancel := context.WithCancel(ctx)
diff --git a/relay.go b/relay.go
index 0319f24..8126422 100644
--- a/relay.go
+++ b/relay.go
@@ -275,7 +275,7 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error
 				}
 			case *CountEnvelope:
 				if subscription, ok := r.Subscriptions.Load(subIdToSerial(env.SubscriptionID)); ok && env.Count != nil && subscription.countResult != nil {
-					subscription.countResult <- *env.Count
+					subscription.countResult <- *env
 				}
 			case *OKEnvelope:
 				if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
@@ -480,11 +480,19 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter) ([]*Event, error)
 }
 
 func (r *Relay) Count(ctx context.Context, filters Filters, opts ...SubscriptionOption) (int64, error) {
+	v, err := r.countInternal(ctx, filters, opts...)
+	if err != nil {
+		return 0, err
+	}
+	return *v.Count, nil
+}
+
+func (r *Relay) countInternal(ctx context.Context, filters Filters, opts ...SubscriptionOption) (CountEnvelope, error) {
 	sub := r.PrepareSubscription(ctx, filters, opts...)
-	sub.countResult = make(chan int64)
+	sub.countResult = make(chan CountEnvelope)
 
 	if err := sub.Fire(); err != nil {
-		return 0, err
+		return CountEnvelope{}, err
 	}
 
 	defer sub.Unsub()
@@ -501,7 +509,7 @@ func (r *Relay) Count(ctx context.Context, filters Filters, opts ...Subscription
 		case count := <-sub.countResult:
 			return count, nil
 		case <-ctx.Done():
-			return 0, ctx.Err()
+			return CountEnvelope{}, ctx.Err()
 		}
 	}
 }
diff --git a/subscription.go b/subscription.go
index d4c1af4..1adf62d 100644
--- a/subscription.go
+++ b/subscription.go
@@ -15,7 +15,7 @@ type Subscription struct {
 	Filters Filters
 
 	// for this to be treated as a COUNT and not a REQ this must be set
-	countResult chan int64
+	countResult chan CountEnvelope
 
 	// the Events channel emits all EVENTs that come in a Subscription
 	// will be closed when the subscription ends
@@ -152,7 +152,7 @@ func (sub *Subscription) Fire() error {
 	if sub.countResult == nil {
 		reqb, _ = ReqEnvelope{sub.id, sub.Filters}.MarshalJSON()
 	} else {
-		reqb, _ = CountEnvelope{sub.id, sub.Filters, nil}.MarshalJSON()
+		reqb, _ = CountEnvelope{SubscriptionID: sub.id, Filters: sub.Filters}.MarshalJSON()
 	}
 
 	sub.live.Store(true)