relay: auth handler; pool: rename auth handler.

This commit is contained in:
fiatjaf
2026-01-08 18:54:19 -03:00
parent 416e11b868
commit 3335c29389
2 changed files with 23 additions and 15 deletions

26
pool.go
View File

@@ -24,7 +24,7 @@ type Pool struct {
Relays *xsync.MapOf[string, *Relay] Relays *xsync.MapOf[string, *Relay]
Context context.Context Context context.Context
authHandler func(context.Context, *Event) error authRequiredHandler func(context.Context, *Event) error
cancel context.CancelCauseFunc cancel context.CancelCauseFunc
eventMiddleware func(RelayEvent) eventMiddleware func(RelayEvent)
@@ -59,7 +59,7 @@ func NewPool(opts PoolOptions) *Pool {
Context: ctx, Context: ctx,
cancel: cancel, cancel: cancel,
authHandler: opts.AuthHandler, authRequiredHandler: opts.AuthRequiredHandler,
eventMiddleware: opts.EventMiddleware, eventMiddleware: opts.EventMiddleware,
duplicateMiddleware: opts.DuplicateMiddleware, duplicateMiddleware: opts.DuplicateMiddleware,
queryMiddleware: opts.AuthorKindQueryMiddleware, queryMiddleware: opts.AuthorKindQueryMiddleware,
@@ -74,10 +74,10 @@ func NewPool(opts PoolOptions) *Pool {
} }
type PoolOptions struct { type PoolOptions struct {
// AuthHandler, if given, must be a function that signs the auth event when called. // AuthRequiredHandler, if given, must be a function that signs the auth event when called.
// it will be called whenever any relay in the pool returns a `CLOSED` message // it will be called whenever any relay in the pool returns a `CLOSED` or `OK` message
// with the "auth-required:" prefix, only once for each relay // with the "auth-required:" prefix, only once for each relay
AuthHandler func(context.Context, *Event) error AuthRequiredHandler func(context.Context, *Event) error
// PenaltyBox just sets the penalty box mechanism so relays that fail to connect // PenaltyBox just sets the penalty box mechanism so relays that fail to connect
// or that disconnect will be ignored for a while and we won't attempt to connect again. // or that disconnect will be ignored for a while and we won't attempt to connect again.
@@ -202,9 +202,9 @@ func (pool *Pool) PublishMany(ctx context.Context, urls []string, evt Event) cha
if err := relay.Publish(ctx, evt); err == nil { if err := relay.Publish(ctx, evt); err == nil {
// success with no auth required // success with no auth required
ch <- PublishResult{nil, url, relay} ch <- PublishResult{nil, url, relay}
} else if strings.HasPrefix(err.Error(), "msg: auth-required:") && pool.authHandler != nil { } else if strings.HasPrefix(err.Error(), "msg: auth-required:") && pool.authRequiredHandler != nil {
// try to authenticate if we can // try to authenticate if we can
if authErr := relay.Auth(ctx, pool.authHandler); authErr == nil { if authErr := relay.Auth(ctx, pool.authRequiredHandler); authErr == nil {
if err := relay.Publish(ctx, evt); err == nil { if err := relay.Publish(ctx, evt); err == nil {
// success after auth // success after auth
ch <- PublishResult{nil, url, relay} ch <- PublishResult{nil, url, relay}
@@ -389,9 +389,9 @@ func (pool *Pool) FetchManyReplaceable(
case <-sub.EndOfStoredEvents: case <-sub.EndOfStoredEvents:
return return
case reason := <-sub.ClosedReason: case reason := <-sub.ClosedReason:
if strings.HasPrefix(reason, "auth-required:") && pool.authHandler != nil && !hasAuthed { if strings.HasPrefix(reason, "auth-required:") && pool.authRequiredHandler != nil && !hasAuthed {
// relay is requesting auth. if we can we will perform auth and try again // relay is requesting auth. if we can we will perform auth and try again
err := relay.Auth(ctx, pool.authHandler) err := relay.Auth(ctx, pool.authRequiredHandler)
if err == nil { if err == nil {
hasAuthed = true // so we don't keep doing AUTH again and again hasAuthed = true // so we don't keep doing AUTH again and again
goto subscribe goto subscribe
@@ -561,9 +561,9 @@ func (pool *Pool) subMany(
} }
} }
case reason := <-sub.ClosedReason: case reason := <-sub.ClosedReason:
if strings.HasPrefix(reason, "auth-required:") && pool.authHandler != nil && !hasAuthed { if strings.HasPrefix(reason, "auth-required:") && pool.authRequiredHandler != nil && !hasAuthed {
// relay is requesting auth. if we can we will perform auth and try again // relay is requesting auth. if we can we will perform auth and try again
err := relay.Auth(ctx, pool.authHandler) err := relay.Auth(ctx, pool.authRequiredHandler)
if err == nil { if err == nil {
hasAuthed = true // so we don't keep doing AUTH again and again hasAuthed = true // so we don't keep doing AUTH again and again
if closedChan != nil { if closedChan != nil {
@@ -659,9 +659,9 @@ func (pool *Pool) subManyEose(
case <-sub.EndOfStoredEvents: case <-sub.EndOfStoredEvents:
return return
case reason := <-sub.ClosedReason: case reason := <-sub.ClosedReason:
if strings.HasPrefix(reason, "auth-required:") && pool.authHandler != nil && !hasAuthed { if strings.HasPrefix(reason, "auth-required:") && pool.authRequiredHandler != nil && !hasAuthed {
// relay is requesting auth. if we can we will perform auth and try again // relay is requesting auth. if we can we will perform auth and try again
err := relay.Auth(ctx, pool.authHandler) err := relay.Auth(ctx, pool.authRequiredHandler)
if err == nil { if err == nil {
hasAuthed = true // so we don't keep doing AUTH again and again hasAuthed = true // so we don't keep doing AUTH again and again
if closedChan != nil { if closedChan != nil {

View File

@@ -40,6 +40,7 @@ type Relay struct {
connectionContextCancel context.CancelCauseFunc connectionContextCancel context.CancelCauseFunc
challenge string // NIP-42 challenge, we only keep the last challenge string // NIP-42 challenge, we only keep the last
authHandler func(context.Context, *Event) error
noticeHandler func(string) // NIP-01 NOTICEs noticeHandler func(string) // NIP-01 NOTICEs
customHandler func(string) // nonstandard unparseable messages customHandler func(string) // nonstandard unparseable messages
okCallbacks map[ID]okcallback okCallbacks map[ID]okcallback
@@ -64,6 +65,7 @@ func NewRelay(ctx context.Context, url string, opts RelayOptions) *Relay {
requestHeader: opts.RequestHeader, requestHeader: opts.RequestHeader,
customHandler: opts.CustomHandler, customHandler: opts.CustomHandler,
noticeHandler: opts.NoticeHandler, noticeHandler: opts.NoticeHandler,
authHandler: opts.AuthHandler,
} }
return r return r
@@ -82,6 +84,9 @@ func RelayConnect(ctx context.Context, url string, opts RelayOptions) (*Relay, e
} }
type RelayOptions struct { type RelayOptions struct {
// AuthHandler is fired when an AUTH message is received. It is given the AUTH event, unsigned, and expects you to sign it.
AuthHandler func(context.Context, *Event) error
// NoticeHandler just takes notices and is expected to do something with them. // NoticeHandler just takes notices and is expected to do something with them.
// When not given defaults to logging the notices. // When not given defaults to logging the notices.
NoticeHandler func(notice string) NoticeHandler func(notice string)
@@ -184,6 +189,9 @@ func (r *Relay) handleMessage(message string) {
return return
} }
r.challenge = *env.Challenge r.challenge = *env.Challenge
if r.authHandler != nil {
r.Auth(r.Context(), r.authHandler)
}
case *EventEnvelope: case *EventEnvelope:
// we already have the subscription from the pre-check above, so we can just reuse it // we already have the subscription from the pre-check above, so we can just reuse it
if sub == nil { if sub == nil {