Merge branch 'hyperloglog'
This commit is contained in:
11
envelopes.go
11
envelopes.go
@@ -2,6 +2,7 @@ package nostr
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@@ -142,6 +143,7 @@ type CountEnvelope struct {
|
|||||||
SubscriptionID string
|
SubscriptionID string
|
||||||
Filters
|
Filters
|
||||||
Count *int64
|
Count *int64
|
||||||
|
HyperLogLog []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (_ CountEnvelope) Label() string { return "COUNT" }
|
func (_ CountEnvelope) Label() string { return "COUNT" }
|
||||||
@@ -160,9 +162,11 @@ func (v *CountEnvelope) UnmarshalJSON(data []byte) error {
|
|||||||
|
|
||||||
var countResult struct {
|
var countResult struct {
|
||||||
Count *int64 `json:"count"`
|
Count *int64 `json:"count"`
|
||||||
|
HLL string `json:"hll"`
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal([]byte(arr[2].Raw), &countResult); err == nil && countResult.Count != nil {
|
if err := json.Unmarshal([]byte(arr[2].Raw), &countResult); err == nil && countResult.Count != nil {
|
||||||
v.Count = countResult.Count
|
v.Count = countResult.Count
|
||||||
|
v.HyperLogLog, _ = hex.DecodeString(countResult.HLL)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,6 +192,13 @@ func (v CountEnvelope) MarshalJSON() ([]byte, error) {
|
|||||||
if v.Count != nil {
|
if v.Count != nil {
|
||||||
w.RawString(`,{"count":`)
|
w.RawString(`,{"count":`)
|
||||||
w.RawString(strconv.FormatInt(*v.Count, 10))
|
w.RawString(strconv.FormatInt(*v.Count, 10))
|
||||||
|
if v.HyperLogLog != nil {
|
||||||
|
w.RawString(`,"hll":"`)
|
||||||
|
hllHex := make([]byte, 512)
|
||||||
|
hex.Encode(hllHex, v.HyperLogLog)
|
||||||
|
w.Buffer.AppendBytes(hllHex)
|
||||||
|
w.RawString(`"`)
|
||||||
|
}
|
||||||
w.RawString(`}`)
|
w.RawString(`}`)
|
||||||
} else {
|
} else {
|
||||||
for _, filter := range v.Filters {
|
for _, filter := range v.Filters {
|
||||||
|
|||||||
42
nip45/hll_event.go
Normal file
42
nip45/hll_event.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package nip45
|
||||||
|
|
||||||
|
import (
|
||||||
|
"iter"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/nbd-wtf/go-nostr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func HyperLogLogEventPubkeyOffsetsAndReferencesForEvent(evt *nostr.Event) iter.Seq2[string, int] {
|
||||||
|
return func(yield func(string, int) bool) {
|
||||||
|
switch evt.Kind {
|
||||||
|
case 3:
|
||||||
|
//
|
||||||
|
// follower counts
|
||||||
|
for _, tag := range evt.Tags {
|
||||||
|
if len(tag) >= 2 && tag[0] == "p" && nostr.IsValid32ByteHex(tag[1]) {
|
||||||
|
// 32th nibble of each "p" tag
|
||||||
|
p, _ := strconv.ParseInt(tag[1][32:33], 16, 64)
|
||||||
|
if !yield(tag[1], int(p+8)) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case 7:
|
||||||
|
//
|
||||||
|
// reaction counts:
|
||||||
|
// (only the last "e" tag counts)
|
||||||
|
lastE := evt.Tags.GetLast([]string{"e", ""})
|
||||||
|
if lastE != nil {
|
||||||
|
v := (*lastE)[1]
|
||||||
|
if nostr.IsValid32ByteHex(v) {
|
||||||
|
// 32th nibble of "e" tag
|
||||||
|
p, _ := strconv.ParseInt(v[32:33], 16, 64)
|
||||||
|
if !yield(v, int(p+8)) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
49
nip45/hll_filter.go
Normal file
49
nip45/hll_filter.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package nip45
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/nbd-wtf/go-nostr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HyperLogLogEventPubkeyOffsetForFilter returns the deterministic pubkey offset that will be used
|
||||||
|
// when computing hyperloglogs in the context of a specific filter.
|
||||||
|
//
|
||||||
|
// It returns -1 when the filter is not eligible for hyperloglog calculation.
|
||||||
|
func HyperLogLogEventPubkeyOffsetForFilter(filter nostr.Filter) int {
|
||||||
|
if filter.IDs != nil || filter.Since != nil || filter.Until != nil || filter.Authors != nil ||
|
||||||
|
len(filter.Kinds) != 1 || filter.Search != "" || len(filter.Tags) != 1 {
|
||||||
|
// obvious cases in which we won't bother to do hyperloglog stuff
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// only serve the cases explicitly defined by the NIP:
|
||||||
|
if pTags, ok := filter.Tags["p"]; ok {
|
||||||
|
//
|
||||||
|
// follower counts:
|
||||||
|
if filter.Kinds[0] == 3 && len(pTags) == 1 {
|
||||||
|
// 32th nibble of "p" tag
|
||||||
|
p, err := strconv.ParseInt(pTags[0][32:33], 16, 64)
|
||||||
|
if err != nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return int(p + 8)
|
||||||
|
}
|
||||||
|
} else if eTags, ok := filter.Tags["e"]; ok {
|
||||||
|
if len(eTags) == 1 {
|
||||||
|
//
|
||||||
|
// reaction counts:
|
||||||
|
if filter.Kinds[0] == 7 {
|
||||||
|
// 32th nibble of "e" tag
|
||||||
|
p, err := strconv.ParseInt(eTags[0][32:33], 16, 64)
|
||||||
|
if err != nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return int(p + 8)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// everything else is false at least for now
|
||||||
|
return -1
|
||||||
|
}
|
||||||
30
nip45/hyperloglog/helpers.go
Normal file
30
nip45/hyperloglog/helpers.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package hyperloglog
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
)
|
||||||
|
|
||||||
|
const two32 = 1 << 32
|
||||||
|
|
||||||
|
func linearCounting(m uint32, v uint32) float64 {
|
||||||
|
fm := float64(m)
|
||||||
|
return fm * math.Log(fm/float64(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
func clz56(x uint64) uint8 {
|
||||||
|
var c uint8
|
||||||
|
for m := uint64(1 << 55); m&x == 0 && m != 0; m >>= 1 {
|
||||||
|
c++
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func countZeros(s []uint8) uint32 {
|
||||||
|
var c uint32
|
||||||
|
for _, v := range s {
|
||||||
|
if v == 0 {
|
||||||
|
c++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
115
nip45/hyperloglog/hll.go
Normal file
115
nip45/hyperloglog/hll.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package hyperloglog
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Everything is hardcoded to use precision 8, i.e. 256 registers.
|
||||||
|
type HyperLogLog struct {
|
||||||
|
offset int
|
||||||
|
registers []uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(offset int) *HyperLogLog {
|
||||||
|
if offset < 0 || offset > 32-8 {
|
||||||
|
panic(fmt.Errorf("invalid offset %d", offset))
|
||||||
|
}
|
||||||
|
|
||||||
|
// precision is always 8
|
||||||
|
// the number of registers is always 256 (1<<8)
|
||||||
|
hll := &HyperLogLog{offset: offset}
|
||||||
|
hll.registers = make([]uint8, 256)
|
||||||
|
return hll
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWithRegisters(registers []byte, offset int) *HyperLogLog {
|
||||||
|
if offset < 0 || offset > 32-8 {
|
||||||
|
panic(fmt.Errorf("invalid offset %d", offset))
|
||||||
|
}
|
||||||
|
if len(registers) != 256 {
|
||||||
|
panic(fmt.Errorf("invalid number of registers %d", len(registers)))
|
||||||
|
}
|
||||||
|
return &HyperLogLog{registers: registers, offset: offset}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hll *HyperLogLog) GetRegisters() []byte { return hll.registers }
|
||||||
|
func (hll *HyperLogLog) SetRegisters(enc []byte) { hll.registers = enc }
|
||||||
|
func (hll *HyperLogLog) MergeRegisters(other []byte) {
|
||||||
|
for i, v := range other {
|
||||||
|
if v > hll.registers[i] {
|
||||||
|
hll.registers[i] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hll *HyperLogLog) Clear() {
|
||||||
|
for i := range hll.registers {
|
||||||
|
hll.registers[i] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add takes a Nostr event pubkey which will be used as the item "key" (that combined with the offset)
|
||||||
|
func (hll *HyperLogLog) Add(pubkey string) {
|
||||||
|
x, _ := hex.DecodeString(pubkey[hll.offset*2 : hll.offset*2+8*2])
|
||||||
|
j := x[0] // register address (first 8 bits, i.e. first byte)
|
||||||
|
|
||||||
|
w := binary.BigEndian.Uint64(x) // number that we will use
|
||||||
|
zeroBits := clz56(w) + 1 // count zeroes (skip the first byte, so only use 56 bits)
|
||||||
|
|
||||||
|
if zeroBits > hll.registers[j] {
|
||||||
|
hll.registers[j] = zeroBits
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBytes is like Add, but takes pubkey as bytes instead of as string
|
||||||
|
func (hll *HyperLogLog) AddBytes(pubkey []byte) {
|
||||||
|
x := pubkey[hll.offset : hll.offset+8]
|
||||||
|
j := x[0] // register address (first 8 bits, i.e. first byte)
|
||||||
|
|
||||||
|
w := binary.BigEndian.Uint64(x) // number that we will use
|
||||||
|
zeroBits := clz56(w) + 1 // count zeroes (skip the first byte, so only use 56 bits)
|
||||||
|
|
||||||
|
if zeroBits > hll.registers[j] {
|
||||||
|
hll.registers[j] = zeroBits
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hll *HyperLogLog) Merge(other *HyperLogLog) {
|
||||||
|
for i, v := range other.registers {
|
||||||
|
if v > hll.registers[i] {
|
||||||
|
hll.registers[i] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hll *HyperLogLog) Count() uint64 {
|
||||||
|
v := countZeros(hll.registers)
|
||||||
|
|
||||||
|
if v != 0 {
|
||||||
|
lc := linearCounting(256 /* nregisters */, v)
|
||||||
|
|
||||||
|
if lc <= 220 /* threshold */ {
|
||||||
|
return uint64(lc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
est := hll.calculateEstimate()
|
||||||
|
if est <= 256 /* nregisters */ *3 {
|
||||||
|
if v != 0 {
|
||||||
|
return uint64(linearCounting(256 /* nregisters */, v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint64(est)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hll HyperLogLog) calculateEstimate() float64 {
|
||||||
|
sum := 0.0
|
||||||
|
for _, val := range hll.registers {
|
||||||
|
sum += 1.0 / float64(uint64(1)<<val) // this is the same as 2^(-val)
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0.7182725932495458 /* alpha for 256 registers */ * 256 /* nregisters */ * 256 /* nregisters */ / sum
|
||||||
|
}
|
||||||
130
nip45/hyperloglog/hll_test.go
Normal file
130
nip45/hyperloglog/hll_test.go
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
package hyperloglog
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"math/rand/v2"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHyperLogLogBasic(t *testing.T) {
|
||||||
|
rand := rand.New(rand.NewPCG(1, 0))
|
||||||
|
|
||||||
|
for _, count := range []int{
|
||||||
|
2, 4, 6, 7, 12, 15, 22, 36, 44, 47,
|
||||||
|
64, 77, 89, 95, 104, 116, 122, 144,
|
||||||
|
150, 199, 300, 350, 400, 500, 600,
|
||||||
|
777, 922, 1000, 1500, 2222, 9999,
|
||||||
|
13600, 80000, 133333, 200000,
|
||||||
|
} {
|
||||||
|
hll := New()
|
||||||
|
|
||||||
|
for range count {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = uint8(rand.UintN(256))
|
||||||
|
}
|
||||||
|
id := hex.EncodeToString(b)
|
||||||
|
hll.Add(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
c := hll.Count()
|
||||||
|
res100 := int(c * 100)
|
||||||
|
require.Greater(t, res100, count*85, "result too low (actual %d < %d)", c, count)
|
||||||
|
require.Less(t, res100, count*115, "result too high (actual %d > %d)", c, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHyperLogLogMerge(t *testing.T) {
|
||||||
|
rand := rand.New(rand.NewPCG(2, 0))
|
||||||
|
|
||||||
|
for _, count := range []int{
|
||||||
|
2, 4, 6, 7, 12, 15, 22, 36, 44, 47,
|
||||||
|
64, 77, 89, 95, 104, 116, 122, 144,
|
||||||
|
150, 199, 300, 350, 400, 500, 600,
|
||||||
|
777, 922, 1000, 1500, 2222, 9999,
|
||||||
|
13600, 80000, 133333, 200000,
|
||||||
|
} {
|
||||||
|
hllA := New()
|
||||||
|
hllB := New()
|
||||||
|
|
||||||
|
for range count / 2 {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = uint8(rand.UintN(256))
|
||||||
|
}
|
||||||
|
id := hex.EncodeToString(b)
|
||||||
|
hllA.Add(id)
|
||||||
|
}
|
||||||
|
for range count / 2 {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = uint8(rand.UintN(256))
|
||||||
|
}
|
||||||
|
id := hex.EncodeToString(b)
|
||||||
|
hllB.Add(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
hll := New()
|
||||||
|
hll.Merge(hllA)
|
||||||
|
hll.Merge(hllB)
|
||||||
|
|
||||||
|
res100 := int(hll.Count() * 100)
|
||||||
|
require.Greater(t, res100, count*85, "result too low (actual %d < %d)", hll.Count(), count)
|
||||||
|
require.Less(t, res100, count*115, "result too high (actual %d > %d)", hll.Count(), count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHyperLogLogMergeComplex(t *testing.T) {
|
||||||
|
rand := rand.New(rand.NewPCG(4, 0))
|
||||||
|
|
||||||
|
for _, count := range []int{
|
||||||
|
3, 6, 9, 12, 15, 22, 36, 46, 57,
|
||||||
|
64, 77, 89, 95, 104, 116, 122, 144,
|
||||||
|
150, 199, 300, 350, 400, 500, 600,
|
||||||
|
777, 922, 1000, 1500, 2222, 9999,
|
||||||
|
13600, 80000, 133333, 200000,
|
||||||
|
} {
|
||||||
|
hllA := New()
|
||||||
|
hllB := New()
|
||||||
|
hllC := New()
|
||||||
|
|
||||||
|
for range count / 3 {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = uint8(rand.UintN(256))
|
||||||
|
}
|
||||||
|
id := hex.EncodeToString(b)
|
||||||
|
hllA.Add(id)
|
||||||
|
hllC.Add(id)
|
||||||
|
}
|
||||||
|
for range count / 3 {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = uint8(rand.UintN(256))
|
||||||
|
}
|
||||||
|
id := hex.EncodeToString(b)
|
||||||
|
hllB.Add(id)
|
||||||
|
hllC.Add(id)
|
||||||
|
}
|
||||||
|
for range count / 3 {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = uint8(rand.UintN(256))
|
||||||
|
}
|
||||||
|
id := hex.EncodeToString(b)
|
||||||
|
hllC.Add(id)
|
||||||
|
hllA.Add(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
hll := New()
|
||||||
|
hll.Merge(hllA)
|
||||||
|
hll.Merge(hllB)
|
||||||
|
hll.Merge(hllC)
|
||||||
|
|
||||||
|
res100 := int(hll.Count() * 100)
|
||||||
|
require.Greater(t, res100, count*85, "result too low (actual %d < %d)", hll.Count(), count)
|
||||||
|
require.Less(t, res100, count*115, "result too high (actual %d > %d)", hll.Count(), count)
|
||||||
|
}
|
||||||
|
}
|
||||||
34
pool.go
34
pool.go
@@ -10,6 +10,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/nbd-wtf/go-nostr/nip45/hyperloglog"
|
||||||
"github.com/puzpuzpuz/xsync/v3"
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -468,6 +469,39 @@ func (pool *SimplePool) subManyEose(
|
|||||||
return events
|
return events
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountMany aggregates count results from multiple relays using HyperLogLog
|
||||||
|
func (pool *SimplePool) CountMany(
|
||||||
|
ctx context.Context,
|
||||||
|
urls []string,
|
||||||
|
filter Filter,
|
||||||
|
opts []SubscriptionOption,
|
||||||
|
) int {
|
||||||
|
hll := hyperloglog.New(0) // offset is irrelevant here, so we just pass 0
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(len(urls))
|
||||||
|
for _, url := range urls {
|
||||||
|
go func(nm string) {
|
||||||
|
defer wg.Done()
|
||||||
|
relay, err := pool.EnsureRelay(url)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ce, err := relay.countInternal(ctx, Filters{filter}, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(ce.HyperLogLog) != 256 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hll.MergeRegisters(ce.HyperLogLog)
|
||||||
|
}(NormalizeURL(url))
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
return int(hll.Count())
|
||||||
|
}
|
||||||
|
|
||||||
// QuerySingle returns the first event returned by the first relay, cancels everything else.
|
// QuerySingle returns the first event returned by the first relay, cancels everything else.
|
||||||
func (pool *SimplePool) QuerySingle(ctx context.Context, urls []string, filter Filter) *RelayEvent {
|
func (pool *SimplePool) QuerySingle(ctx context.Context, urls []string, filter Filter) *RelayEvent {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|||||||
16
relay.go
16
relay.go
@@ -275,7 +275,7 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error
|
|||||||
}
|
}
|
||||||
case *CountEnvelope:
|
case *CountEnvelope:
|
||||||
if subscription, ok := r.Subscriptions.Load(subIdToSerial(env.SubscriptionID)); ok && env.Count != nil && subscription.countResult != nil {
|
if subscription, ok := r.Subscriptions.Load(subIdToSerial(env.SubscriptionID)); ok && env.Count != nil && subscription.countResult != nil {
|
||||||
subscription.countResult <- *env.Count
|
subscription.countResult <- *env
|
||||||
}
|
}
|
||||||
case *OKEnvelope:
|
case *OKEnvelope:
|
||||||
if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
|
if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
|
||||||
@@ -480,11 +480,19 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter) ([]*Event, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Relay) Count(ctx context.Context, filters Filters, opts ...SubscriptionOption) (int64, error) {
|
func (r *Relay) Count(ctx context.Context, filters Filters, opts ...SubscriptionOption) (int64, error) {
|
||||||
|
v, err := r.countInternal(ctx, filters, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return *v.Count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Relay) countInternal(ctx context.Context, filters Filters, opts ...SubscriptionOption) (CountEnvelope, error) {
|
||||||
sub := r.PrepareSubscription(ctx, filters, opts...)
|
sub := r.PrepareSubscription(ctx, filters, opts...)
|
||||||
sub.countResult = make(chan int64)
|
sub.countResult = make(chan CountEnvelope)
|
||||||
|
|
||||||
if err := sub.Fire(); err != nil {
|
if err := sub.Fire(); err != nil {
|
||||||
return 0, err
|
return CountEnvelope{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer sub.Unsub()
|
defer sub.Unsub()
|
||||||
@@ -501,7 +509,7 @@ func (r *Relay) Count(ctx context.Context, filters Filters, opts ...Subscription
|
|||||||
case count := <-sub.countResult:
|
case count := <-sub.countResult:
|
||||||
return count, nil
|
return count, nil
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return 0, ctx.Err()
|
return CountEnvelope{}, ctx.Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ type Subscription struct {
|
|||||||
Filters Filters
|
Filters Filters
|
||||||
|
|
||||||
// for this to be treated as a COUNT and not a REQ this must be set
|
// for this to be treated as a COUNT and not a REQ this must be set
|
||||||
countResult chan int64
|
countResult chan CountEnvelope
|
||||||
|
|
||||||
// the Events channel emits all EVENTs that come in a Subscription
|
// the Events channel emits all EVENTs that come in a Subscription
|
||||||
// will be closed when the subscription ends
|
// will be closed when the subscription ends
|
||||||
@@ -152,7 +152,7 @@ func (sub *Subscription) Fire() error {
|
|||||||
if sub.countResult == nil {
|
if sub.countResult == nil {
|
||||||
reqb, _ = ReqEnvelope{sub.id, sub.Filters}.MarshalJSON()
|
reqb, _ = ReqEnvelope{sub.id, sub.Filters}.MarshalJSON()
|
||||||
} else {
|
} else {
|
||||||
reqb, _ = CountEnvelope{sub.id, sub.Filters, nil}.MarshalJSON()
|
reqb, _ = CountEnvelope{sub.id, sub.Filters, nil, nil}.MarshalJSON()
|
||||||
}
|
}
|
||||||
|
|
||||||
sub.live.Store(true)
|
sub.live.Store(true)
|
||||||
|
|||||||
Reference in New Issue
Block a user