diff --git a/pool.go b/pool.go index 33a094f..9093b09 100644 --- a/pool.go +++ b/pool.go @@ -24,8 +24,8 @@ type Pool struct { Relays *xsync.MapOf[string, *Relay] Context context.Context - authHandler func(context.Context, *Event) error - cancel context.CancelCauseFunc + authRequiredHandler func(context.Context, *Event) error + cancel context.CancelCauseFunc eventMiddleware func(RelayEvent) duplicateMiddleware func(relay string, id ID) @@ -59,7 +59,7 @@ func NewPool(opts PoolOptions) *Pool { Context: ctx, cancel: cancel, - authHandler: opts.AuthHandler, + authRequiredHandler: opts.AuthRequiredHandler, eventMiddleware: opts.EventMiddleware, duplicateMiddleware: opts.DuplicateMiddleware, queryMiddleware: opts.AuthorKindQueryMiddleware, @@ -74,10 +74,10 @@ func NewPool(opts PoolOptions) *Pool { } type PoolOptions struct { - // AuthHandler, if given, must be a function that signs the auth event when called. - // it will be called whenever any relay in the pool returns a `CLOSED` message + // AuthRequiredHandler, if given, must be a function that signs the auth event when called. + // it will be called whenever any relay in the pool returns a `CLOSED` or `OK` message // with the "auth-required:" prefix, only once for each relay - AuthHandler func(context.Context, *Event) error + AuthRequiredHandler func(context.Context, *Event) error // PenaltyBox just sets the penalty box mechanism so relays that fail to connect // or that disconnect will be ignored for a while and we won't attempt to connect again. @@ -202,9 +202,9 @@ func (pool *Pool) PublishMany(ctx context.Context, urls []string, evt Event) cha if err := relay.Publish(ctx, evt); err == nil { // success with no auth required ch <- PublishResult{nil, url, relay} - } else if strings.HasPrefix(err.Error(), "msg: auth-required:") && pool.authHandler != nil { + } else if strings.HasPrefix(err.Error(), "msg: auth-required:") && pool.authRequiredHandler != nil { // try to authenticate if we can - if authErr := relay.Auth(ctx, pool.authHandler); authErr == nil { + if authErr := relay.Auth(ctx, pool.authRequiredHandler); authErr == nil { if err := relay.Publish(ctx, evt); err == nil { // success after auth ch <- PublishResult{nil, url, relay} @@ -389,9 +389,9 @@ func (pool *Pool) FetchManyReplaceable( case <-sub.EndOfStoredEvents: return case reason := <-sub.ClosedReason: - if strings.HasPrefix(reason, "auth-required:") && pool.authHandler != nil && !hasAuthed { + if strings.HasPrefix(reason, "auth-required:") && pool.authRequiredHandler != nil && !hasAuthed { // relay is requesting auth. if we can we will perform auth and try again - err := relay.Auth(ctx, pool.authHandler) + err := relay.Auth(ctx, pool.authRequiredHandler) if err == nil { hasAuthed = true // so we don't keep doing AUTH again and again goto subscribe @@ -561,9 +561,9 @@ func (pool *Pool) subMany( } } case reason := <-sub.ClosedReason: - if strings.HasPrefix(reason, "auth-required:") && pool.authHandler != nil && !hasAuthed { + if strings.HasPrefix(reason, "auth-required:") && pool.authRequiredHandler != nil && !hasAuthed { // relay is requesting auth. if we can we will perform auth and try again - err := relay.Auth(ctx, pool.authHandler) + err := relay.Auth(ctx, pool.authRequiredHandler) if err == nil { hasAuthed = true // so we don't keep doing AUTH again and again if closedChan != nil { @@ -659,9 +659,9 @@ func (pool *Pool) subManyEose( case <-sub.EndOfStoredEvents: return case reason := <-sub.ClosedReason: - if strings.HasPrefix(reason, "auth-required:") && pool.authHandler != nil && !hasAuthed { + if strings.HasPrefix(reason, "auth-required:") && pool.authRequiredHandler != nil && !hasAuthed { // relay is requesting auth. if we can we will perform auth and try again - err := relay.Auth(ctx, pool.authHandler) + err := relay.Auth(ctx, pool.authRequiredHandler) if err == nil { hasAuthed = true // so we don't keep doing AUTH again and again if closedChan != nil { diff --git a/relay.go b/relay.go index c7e9383..aed6dbf 100644 --- a/relay.go +++ b/relay.go @@ -39,7 +39,8 @@ type Relay struct { connectionContext context.Context // will be canceled when the connection closes connectionContextCancel context.CancelCauseFunc - challenge string // NIP-42 challenge, we only keep the last + challenge string // NIP-42 challenge, we only keep the last + authHandler func(context.Context, *Event) error noticeHandler func(string) // NIP-01 NOTICEs customHandler func(string) // nonstandard unparseable messages okCallbacks map[ID]okcallback @@ -64,6 +65,7 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay { requestHeader: opts.RequestHeader, customHandler: opts.CustomHandler, noticeHandler: opts.NoticeHandler, + authHandler: opts.AuthHandler, } return r @@ -82,6 +84,9 @@ func RelayConnect(ctx context.Context, url string, opts RelayOptions) (*Relay, e } type RelayOptions struct { + // AuthHandler is fired when an AUTH message is received. It is given the AUTH event, unsigned, and expects you to sign it. + AuthHandler func(context.Context, *Event) error + // NoticeHandler just takes notices and is expected to do something with them. // When not given defaults to logging the notices. NoticeHandler func(notice string) @@ -184,6 +189,9 @@ func (r *Relay) handleMessage(message string) { return } r.challenge = *env.Challenge + if r.authHandler != nil { + r.Auth(r.Context(), r.authHandler) + } case *EventEnvelope: // we already have the subscription from the pre-check above, so we can just reuse it if sub == nil {