nest okcallbacks so they're called one by one.

This commit is contained in:
fiatjaf
2025-08-06 15:13:55 -03:00
parent 960312bd74
commit 5d42b2f857

View File

@@ -35,7 +35,8 @@ type Relay struct {
challenge string // NIP-42 challenge, we only keep the last challenge string // NIP-42 challenge, we only keep the last
noticeHandler func(string) // NIP-01 NOTICEs noticeHandler func(string) // NIP-01 NOTICEs
customHandler func(string) // nonstandard unparseable messages customHandler func(string) // nonstandard unparseable messages
okCallbacks *xsync.MapOf[ID, func(bool, string)] okCallbacks map[ID]okcallback
okCallbacksMutex sync.Mutex
subscriptionChannelCloseQueue chan *Subscription subscriptionChannelCloseQueue chan *Subscription
// custom things that aren't often used // custom things that aren't often used
@@ -51,7 +52,7 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay {
connectionContext: ctx, connectionContext: ctx,
connectionContextCancel: cancel, connectionContextCancel: cancel,
Subscriptions: xsync.NewMapOf[int64, *Subscription](), Subscriptions: xsync.NewMapOf[int64, *Subscription](),
okCallbacks: xsync.NewMapOf[ID, func(bool, string)](), okCallbacks: make(map[ID]okcallback, 20),
subscriptionChannelCloseQueue: make(chan *Subscription), subscriptionChannelCloseQueue: make(chan *Subscription),
requestHeader: opts.RequestHeader, requestHeader: opts.RequestHeader,
} }
@@ -203,11 +204,13 @@ func (r *Relay) handleMessage(message string) {
subscription.countResult <- *env subscription.countResult <- *env
} }
case *OKEnvelope: case *OKEnvelope:
if okCallback, exist := r.okCallbacks.Load(env.EventID); exist { r.okCallbacksMutex.Lock()
if okCallback, exist := r.okCallbacks[env.EventID]; exist {
okCallback(env.OK, env.Reason) okCallback(env.OK, env.Reason)
} else { } else {
InfoLogger.Printf("{%s} got an unexpected OK message for event %s", r.URL, env.EventID) InfoLogger.Printf("{%s} got an unexpected OK message for event %s", r.URL, env.EventID)
} }
r.okCallbacksMutex.Unlock()
} }
} }
@@ -275,30 +278,34 @@ func (r *Relay) publish(ctx context.Context, id ID, env Envelope) error {
} }
// listen for an OK callback // listen for an OK callback
gotOk := false gotOk := make(chan bool, 1)
handleOk := func(ok bool, reason string) { handleOk := func(ok bool, reason string) {
gotOk = true err = fmt.Errorf("msg: %s", reason)
if !ok { gotOk <- ok
err = fmt.Errorf("msg: %s", reason)
}
cancel()
} }
r.okCallbacks.Compute(id, func(oldValue func(bool, string), loaded bool) (newValue func(bool, string), delete bool) { r.okCallbacksMutex.Lock()
if !loaded { if previous, exists := r.okCallbacks[id]; !exists {
// normal path: there is nothing listening for this id, so we register this function // normal path: there is nothing listening for this id, so we register this function
return handleOk, false r.okCallbacks[id] = func(ok bool, reason string) {
handleOk(ok, reason)
// and when it's called the mutex will be locked
// so we just eliminate it
delete(r.okCallbacks, id)
} }
} else {
// if the same event is published twice there will be something here already // if the same event is published twice there will be something here already
// so we make a new handleOk() function that concatenates both // so we make a function that concatenates both
return func(ok bool, reason string) { r.okCallbacks[id] = func(ok bool, reason string) {
oldValue(ok, reason) // we call this with an informative helper for the developer
handleOk(ok, fmt.Sprintf("published twice: %s", reason)) // and we inform the developer handleOk(ok, fmt.Sprintf("published twice: %s", reason))
}, false
})
defer r.okCallbacks.Delete(id) // then we replace it with the previous (and when this is called it will nuke itself accordingly)
r.okCallbacks[id] = previous
}
}
r.okCallbacksMutex.Unlock()
// publish event // publish event
envb, _ := env.MarshalJSON() envb, _ := env.MarshalJSON()
@@ -308,14 +315,20 @@ func (r *Relay) publish(ctx context.Context, id ID, env Envelope) error {
for { for {
select { select {
case <-ctx.Done(): case ok := <-gotOk:
// this will be called when we get an OK or when the context has been canceled if ok {
if gotOk { return nil
return err
} }
return err
case <-ctx.Done():
r.okCallbacksMutex.Lock()
if cb, _ := r.okCallbacks[id]; cb != nil {
cb(false, "timeout")
}
r.okCallbacksMutex.Unlock()
return fmt.Errorf("publish: %w", context.Cause(ctx)) return fmt.Errorf("publish: %w", context.Cause(ctx))
case <-r.connectionContext.Done(): case <-r.connectionContext.Done():
// this is caused when we lose connectivity r.okCallbacks = make(map[ID]okcallback)
return fmt.Errorf("relay: %w", context.Cause(r.connectionContext)) return fmt.Errorf("relay: %w", context.Cause(r.connectionContext))
} }
} }
@@ -479,3 +492,5 @@ func (r *Relay) close(reason error) error {
var subIdPool = sync.Pool{ var subIdPool = sync.Pool{
New: func() any { return make([]byte, 0, 15) }, New: func() any { return make([]byte, 0, 15) },
} }
type okcallback func(ok bool, reason string)