From 5d42b2f857e6ccd2801f0e126032d36946a430a5 Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Wed, 6 Aug 2025 15:13:55 -0300 Subject: [PATCH] nest okcallbacks so they're called one by one. --- relay.go | 67 ++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/relay.go b/relay.go index f1b10e3..13ad3e8 100644 --- a/relay.go +++ b/relay.go @@ -35,7 +35,8 @@ type Relay struct { challenge string // NIP-42 challenge, we only keep the last noticeHandler func(string) // NIP-01 NOTICEs customHandler func(string) // nonstandard unparseable messages - okCallbacks *xsync.MapOf[ID, func(bool, string)] + okCallbacks map[ID]okcallback + okCallbacksMutex sync.Mutex subscriptionChannelCloseQueue chan *Subscription // custom things that aren't often used @@ -51,7 +52,7 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay { connectionContext: ctx, connectionContextCancel: cancel, Subscriptions: xsync.NewMapOf[int64, *Subscription](), - okCallbacks: xsync.NewMapOf[ID, func(bool, string)](), + okCallbacks: make(map[ID]okcallback, 20), subscriptionChannelCloseQueue: make(chan *Subscription), requestHeader: opts.RequestHeader, } @@ -203,11 +204,13 @@ func (r *Relay) handleMessage(message string) { subscription.countResult <- *env } case *OKEnvelope: - if okCallback, exist := r.okCallbacks.Load(env.EventID); exist { + r.okCallbacksMutex.Lock() + if okCallback, exist := r.okCallbacks[env.EventID]; exist { okCallback(env.OK, env.Reason) } else { InfoLogger.Printf("{%s} got an unexpected OK message for event %s", r.URL, env.EventID) } + r.okCallbacksMutex.Unlock() } } @@ -275,30 +278,34 @@ func (r *Relay) publish(ctx context.Context, id ID, env Envelope) error { } // listen for an OK callback - gotOk := false + gotOk := make(chan bool, 1) handleOk := func(ok bool, reason string) { - gotOk = true - if !ok { - err = fmt.Errorf("msg: %s", reason) - } - cancel() + err = fmt.Errorf("msg: %s", reason) + gotOk <- ok } - r.okCallbacks.Compute(id, func(oldValue func(bool, string), loaded bool) (newValue func(bool, string), delete bool) { - if !loaded { - // normal path: there is nothing listening for this id, so we register this function - return handleOk, false + r.okCallbacksMutex.Lock() + if previous, exists := r.okCallbacks[id]; !exists { + // normal path: there is nothing listening for this id, so we register this function + r.okCallbacks[id] = func(ok bool, reason string) { + handleOk(ok, reason) + + // and when it's called the mutex will be locked + // so we just eliminate it + delete(r.okCallbacks, id) } - + } else { // if the same event is published twice there will be something here already - // so we make a new handleOk() function that concatenates both - return func(ok bool, reason string) { - oldValue(ok, reason) - handleOk(ok, fmt.Sprintf("published twice: %s", reason)) // and we inform the developer - }, false - }) + // so we make a function that concatenates both + r.okCallbacks[id] = func(ok bool, reason string) { + // we call this with an informative helper for the developer + handleOk(ok, fmt.Sprintf("published twice: %s", reason)) - defer r.okCallbacks.Delete(id) + // then we replace it with the previous (and when this is called it will nuke itself accordingly) + r.okCallbacks[id] = previous + } + } + r.okCallbacksMutex.Unlock() // publish event envb, _ := env.MarshalJSON() @@ -308,14 +315,20 @@ func (r *Relay) publish(ctx context.Context, id ID, env Envelope) error { for { select { - case <-ctx.Done(): - // this will be called when we get an OK or when the context has been canceled - if gotOk { - return err + case ok := <-gotOk: + if ok { + return nil } + return err + case <-ctx.Done(): + r.okCallbacksMutex.Lock() + if cb, _ := r.okCallbacks[id]; cb != nil { + cb(false, "timeout") + } + r.okCallbacksMutex.Unlock() return fmt.Errorf("publish: %w", context.Cause(ctx)) case <-r.connectionContext.Done(): - // this is caused when we lose connectivity + r.okCallbacks = make(map[ID]okcallback) return fmt.Errorf("relay: %w", context.Cause(r.connectionContext)) } } @@ -479,3 +492,5 @@ func (r *Relay) close(reason error) error { var subIdPool = sync.Pool{ New: func() any { return make([]byte, 0, 15) }, } + +type okcallback func(ok bool, reason string)