*connection to be an integral part of *Relay.

This commit is contained in:
fiatjaf
2025-11-12 06:54:41 -03:00
parent 1c43f0d666
commit c2ab9d082c
2 changed files with 54 additions and 71 deletions

View File

@@ -7,8 +7,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/http"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -17,30 +15,13 @@ import (
var ErrDisconnected = errors.New("<disconnected>") var ErrDisconnected = errors.New("<disconnected>")
// Connection represents a websocket connection to a Nostr relay.
type connection struct {
conn *ws.Conn
cancel context.CancelCauseFunc
writeQueue chan writeRequest
closed *atomic.Bool
closedNotify chan struct{}
closeMutex sync.Mutex
}
type writeRequest struct { type writeRequest struct {
msg []byte msg []byte
answer chan error answer chan error
} }
func newConnection( func (r *Relay) newConnection(ctx context.Context, tlsConfig *tls.Config) error {
ctx context.Context, debugLogf("{%s} connecting!\n", r.URL)
cancel context.CancelCauseFunc,
url string,
handleMessage func(string),
requestHeader http.Header,
tlsConfig *tls.Config,
) (*connection, error) {
debugLogf("{%s} connecting!\n", url)
dialCtx := ctx dialCtx := ctx
if _, ok := dialCtx.Deadline(); !ok { if _, ok := dialCtx.Deadline(); !ok {
@@ -48,9 +29,9 @@ func newConnection(
dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long")) dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
} }
c, _, err := ws.Dial(dialCtx, url, getConnectionOptions(requestHeader, tlsConfig)) c, _, err := ws.Dial(dialCtx, r.URL, getConnectionOptions(r.requestHeader, tlsConfig))
if err != nil { if err != nil {
return nil, err return err
} }
c.SetReadLimit(2 << 24) // 33MB c.SetReadLimit(2 << 24) // 33MB
@@ -63,40 +44,37 @@ func newConnection(
writeQueue := make(chan writeRequest) writeQueue := make(chan writeRequest)
readQueue := make(chan string) readQueue := make(chan string)
conn := &connection{ r.conn = c
conn: c, r.writeQueue = writeQueue
cancel: cancel, r.closed = &atomic.Bool{}
writeQueue: writeQueue, r.closedNotify = make(chan struct{})
closed: &atomic.Bool{},
closedNotify: make(chan struct{}),
}
go func() { go func() {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
conn.doClose(ws.StatusNormalClosure, "") r.closeConnection(ws.StatusNormalClosure, "")
debugLogf("{%s} closing!, context done: '%s'\n", url, context.Cause(ctx)) debugLogf("{%s} closing!, context done: '%s'\n", r.URL, context.Cause(ctx))
return return
case <-conn.closedNotify: case <-r.closedNotify:
return return
case <-ticker.C: case <-ticker.C:
ctx, cancel := context.WithTimeoutCause(ctx, time.Millisecond*800, errors.New("ping took too long")) ctx, cancel := context.WithTimeoutCause(ctx, time.Millisecond*800, errors.New("ping took too long"))
err := c.Ping(ctx) err := c.Ping(ctx)
cancel() cancel()
if err != nil { if err != nil {
debugLogf("{%s} closing!, ping failed: '%s'\n", url, err) debugLogf("{%s} closing!, ping failed: '%s'\n", r.URL, err)
conn.doClose(ws.StatusAbnormalClosure, "ping took too long") r.closeConnection(ws.StatusAbnormalClosure, "ping took too long")
return return
} }
case wr := <-writeQueue: case wr := <-writeQueue:
debugLogf("{%s} sending '%v'\n", url, string(wr.msg)) debugLogf("{%s} sending '%v'\n", r.URL, string(wr.msg))
ctx, cancel := context.WithTimeoutCause(ctx, time.Second*10, errors.New("write took too long")) ctx, cancel := context.WithTimeoutCause(ctx, time.Second*10, errors.New("write took too long"))
err := c.Write(ctx, ws.MessageText, wr.msg) err := c.Write(ctx, ws.MessageText, wr.msg)
cancel() cancel()
if err != nil { if err != nil {
debugLogf("{%s} closing!, write failed: '%s'\n", url, err) debugLogf("{%s} closing!, write failed: '%s'\n", r.URL, err)
conn.doClose(ws.StatusAbnormalClosure, "write failed") r.closeConnection(ws.StatusAbnormalClosure, "write failed")
if wr.answer != nil { if wr.answer != nil {
wr.answer <- err wr.answer <- err
} }
@@ -106,8 +84,8 @@ func newConnection(
close(wr.answer) close(wr.answer)
} }
case msg := <-readQueue: case msg := <-readQueue:
debugLogf("{%s} received %v\n", url, msg) debugLogf("{%s} received %v\n", r.URL, msg)
handleMessage(msg) r.handleMessage(msg)
} }
} }
}() }()
@@ -121,13 +99,13 @@ func newConnection(
_, reader, err := c.Reader(ctx) _, reader, err := c.Reader(ctx)
if err != nil { if err != nil {
debugLogf("{%s} closing!, reader failure: '%s'\n", url, err) debugLogf("{%s} closing!, reader failure: '%s'\n", r.URL, err)
conn.doClose(ws.StatusAbnormalClosure, "failed to get reader") r.closeConnection(ws.StatusAbnormalClosure, "failed to get reader")
return return
} }
if _, err := io.Copy(buf, reader); err != nil { if _, err := io.Copy(buf, reader); err != nil {
debugLogf("{%s} closing!, read failure: '%s'\n", url, err) debugLogf("{%s} closing!, read failure: '%s'\n", r.URL, err)
conn.doClose(ws.StatusAbnormalClosure, "failed to read") r.closeConnection(ws.StatusAbnormalClosure, "failed to read")
return return
} }
@@ -135,17 +113,18 @@ func newConnection(
} }
}() }()
return conn, nil return nil
} }
func (c *connection) doClose(code ws.StatusCode, reason string) { func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
wasClosed := c.closed.Swap(true) wasClosed := r.closed.Swap(true)
if !wasClosed { if !wasClosed {
c.conn.Close(code, reason) r.conn.Close(code, reason)
c.cancel(fmt.Errorf("doClose(): %s", reason)) r.connectionContextCancel(fmt.Errorf("doClose(): %s", reason))
c.closeMutex.Lock() r.closeMutex.Lock()
close(c.closedNotify) close(r.closedNotify)
close(c.writeQueue) close(r.writeQueue)
c.closeMutex.Unlock() r.conn = nil
r.closeMutex.Unlock()
} }
} }

View File

@@ -14,6 +14,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
ws "github.com/coder/websocket"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
) )
@@ -26,7 +27,12 @@ type Relay struct {
URL string URL string
requestHeader http.Header // e.g. for origin header requestHeader http.Header // e.g. for origin header
connection *connection // websocket connection
conn *ws.Conn
writeQueue chan writeRequest
closed *atomic.Bool
closedNotify chan struct{}
Subscriptions *xsync.MapOf[int64, *Subscription] Subscriptions *xsync.MapOf[int64, *Subscription]
ConnectionError error ConnectionError error
@@ -98,7 +104,7 @@ func (r *Relay) String() string {
func (r *Relay) Context() context.Context { return r.connectionContext } func (r *Relay) Context() context.Context { return r.connectionContext }
// IsConnected returns true if the connection to this relay seems to be active. // IsConnected returns true if the connection to this relay seems to be active.
func (r *Relay) IsConnected() bool { return !r.connection.closed.Load() } func (r *Relay) IsConnected() bool { return !r.closed.Load() }
// Connect tries to establish a websocket connection to r.URL. // Connect tries to establish a websocket connection to r.URL.
// If the context expires before the connection is complete, an error is returned. // If the context expires before the connection is complete, an error is returned.
@@ -121,11 +127,9 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error
return fmt.Errorf("invalid relay URL '%s'", r.URL) return fmt.Errorf("invalid relay URL '%s'", r.URL)
} }
conn, err := newConnection(ctx, r.connectionContextCancel, r.URL, r.handleMessage, r.requestHeader, tlsConfig) if err := r.newConnection(ctx, tlsConfig); err != nil {
if err != nil {
return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
} }
r.connection = conn
return nil return nil
} }
@@ -219,33 +223,33 @@ func (r *Relay) handleMessage(message string) {
// Write queues an arbitrary message to be sent to the relay. // Write queues an arbitrary message to be sent to the relay.
func (r *Relay) Write(msg []byte) { func (r *Relay) Write(msg []byte) {
r.connection.closeMutex.Lock() r.closeMutex.Lock()
defer r.connection.closeMutex.Unlock() defer r.closeMutex.Unlock()
select { select {
case <-r.connection.closedNotify: case <-r.closedNotify:
return return
default: default:
} }
select { select {
case <-r.connectionContext.Done(): case <-r.connectionContext.Done():
case r.connection.writeQueue <- writeRequest{msg: msg, answer: nil}: case r.writeQueue <- writeRequest{msg: msg, answer: nil}:
} }
} }
// WriteWithError is like Write, but returns an error if the write fails (and the connection gets closed). // WriteWithError is like Write, but returns an error if the write fails (and the connection gets closed).
func (r *Relay) WriteWithError(msg []byte) error { func (r *Relay) WriteWithError(msg []byte) error {
ch := make(chan error) ch := make(chan error)
r.connection.closeMutex.Lock() r.closeMutex.Lock()
defer r.connection.closeMutex.Unlock() defer r.closeMutex.Unlock()
select { select {
case <-r.connection.closedNotify: case <-r.closedNotify:
return fmt.Errorf("failed to write to %s: <closed>", r.URL) return fmt.Errorf("failed to write to %s: <closed>", r.URL)
default: default:
} }
select { select {
case <-r.connectionContext.Done(): case <-r.connectionContext.Done():
return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext)) return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext))
case r.connection.writeQueue <- writeRequest{msg: msg, answer: ch}: case r.writeQueue <- writeRequest{msg: msg, answer: ch}:
} }
return <-ch return <-ch
} }
@@ -357,7 +361,7 @@ func (r *Relay) publish(ctx context.Context, id ID, env Envelope) error {
func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionOptions) (*Subscription, error) { func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionOptions) (*Subscription, error) {
sub := r.PrepareSubscription(ctx, filter, opts) sub := r.PrepareSubscription(ctx, filter, opts)
if r.connection == nil { if r.conn == nil {
return nil, fmt.Errorf("not connected to %s", r.URL) return nil, fmt.Errorf("not connected to %s", r.URL)
} }
@@ -367,7 +371,7 @@ func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionO
go func() { go func() {
select { select {
case <-r.connection.closedNotify: case <-r.closedNotify:
sub.unsub(ErrDisconnected) sub.unsub(ErrDisconnected)
case <-ctx.Done(): case <-ctx.Done():
} }
@@ -510,13 +514,13 @@ func (r *Relay) close(reason error) error {
if r.connectionContextCancel == nil { if r.connectionContextCancel == nil {
return fmt.Errorf("relay already closed") return fmt.Errorf("relay already closed")
} }
r.connectionContextCancel(reason)
r.connectionContextCancel = nil
if r.connection == nil { if r.conn == nil {
return fmt.Errorf("relay not connected") return fmt.Errorf("relay not connected")
} }
r.connectionContextCancel(reason)
return nil return nil
} }