diff --git a/connection.go b/connection.go index 609ce37..7901a66 100644 --- a/connection.go +++ b/connection.go @@ -7,8 +7,6 @@ import ( "errors" "fmt" "io" - "net/http" - "sync" "sync/atomic" "time" @@ -17,30 +15,13 @@ import ( var ErrDisconnected = errors.New("") -// Connection represents a websocket connection to a Nostr relay. -type connection struct { - conn *ws.Conn - cancel context.CancelCauseFunc - writeQueue chan writeRequest - closed *atomic.Bool - closedNotify chan struct{} - closeMutex sync.Mutex -} - type writeRequest struct { msg []byte answer chan error } -func newConnection( - ctx context.Context, - cancel context.CancelCauseFunc, - url string, - handleMessage func(string), - requestHeader http.Header, - tlsConfig *tls.Config, -) (*connection, error) { - debugLogf("{%s} connecting!\n", url) +func (r *Relay) newConnection(ctx context.Context, tlsConfig *tls.Config) error { + debugLogf("{%s} connecting!\n", r.URL) dialCtx := ctx if _, ok := dialCtx.Deadline(); !ok { @@ -48,9 +29,9 @@ func newConnection( dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long")) } - c, _, err := ws.Dial(dialCtx, url, getConnectionOptions(requestHeader, tlsConfig)) + c, _, err := ws.Dial(dialCtx, r.URL, getConnectionOptions(r.requestHeader, tlsConfig)) if err != nil { - return nil, err + return err } c.SetReadLimit(2 << 24) // 33MB @@ -63,40 +44,37 @@ func newConnection( writeQueue := make(chan writeRequest) readQueue := make(chan string) - conn := &connection{ - conn: c, - cancel: cancel, - writeQueue: writeQueue, - closed: &atomic.Bool{}, - closedNotify: make(chan struct{}), - } + r.conn = c + r.writeQueue = writeQueue + r.closed = &atomic.Bool{} + r.closedNotify = make(chan struct{}) go func() { for { select { case <-ctx.Done(): - conn.doClose(ws.StatusNormalClosure, "") - debugLogf("{%s} closing!, context done: '%s'\n", url, context.Cause(ctx)) + r.closeConnection(ws.StatusNormalClosure, "") + debugLogf("{%s} closing!, context done: '%s'\n", r.URL, context.Cause(ctx)) return - case <-conn.closedNotify: + case <-r.closedNotify: return case <-ticker.C: ctx, cancel := context.WithTimeoutCause(ctx, time.Millisecond*800, errors.New("ping took too long")) err := c.Ping(ctx) cancel() if err != nil { - debugLogf("{%s} closing!, ping failed: '%s'\n", url, err) - conn.doClose(ws.StatusAbnormalClosure, "ping took too long") + debugLogf("{%s} closing!, ping failed: '%s'\n", r.URL, err) + r.closeConnection(ws.StatusAbnormalClosure, "ping took too long") return } case wr := <-writeQueue: - debugLogf("{%s} sending '%v'\n", url, string(wr.msg)) + debugLogf("{%s} sending '%v'\n", r.URL, string(wr.msg)) ctx, cancel := context.WithTimeoutCause(ctx, time.Second*10, errors.New("write took too long")) err := c.Write(ctx, ws.MessageText, wr.msg) cancel() if err != nil { - debugLogf("{%s} closing!, write failed: '%s'\n", url, err) - conn.doClose(ws.StatusAbnormalClosure, "write failed") + debugLogf("{%s} closing!, write failed: '%s'\n", r.URL, err) + r.closeConnection(ws.StatusAbnormalClosure, "write failed") if wr.answer != nil { wr.answer <- err } @@ -106,8 +84,8 @@ func newConnection( close(wr.answer) } case msg := <-readQueue: - debugLogf("{%s} received %v\n", url, msg) - handleMessage(msg) + debugLogf("{%s} received %v\n", r.URL, msg) + r.handleMessage(msg) } } }() @@ -121,13 +99,13 @@ func newConnection( _, reader, err := c.Reader(ctx) if err != nil { - debugLogf("{%s} closing!, reader failure: '%s'\n", url, err) - conn.doClose(ws.StatusAbnormalClosure, "failed to get reader") + debugLogf("{%s} closing!, reader failure: '%s'\n", r.URL, err) + r.closeConnection(ws.StatusAbnormalClosure, "failed to get reader") return } if _, err := io.Copy(buf, reader); err != nil { - debugLogf("{%s} closing!, read failure: '%s'\n", url, err) - conn.doClose(ws.StatusAbnormalClosure, "failed to read") + debugLogf("{%s} closing!, read failure: '%s'\n", r.URL, err) + r.closeConnection(ws.StatusAbnormalClosure, "failed to read") return } @@ -135,17 +113,18 @@ func newConnection( } }() - return conn, nil + return nil } -func (c *connection) doClose(code ws.StatusCode, reason string) { - wasClosed := c.closed.Swap(true) +func (r *Relay) closeConnection(code ws.StatusCode, reason string) { + wasClosed := r.closed.Swap(true) if !wasClosed { - c.conn.Close(code, reason) - c.cancel(fmt.Errorf("doClose(): %s", reason)) - c.closeMutex.Lock() - close(c.closedNotify) - close(c.writeQueue) - c.closeMutex.Unlock() + r.conn.Close(code, reason) + r.connectionContextCancel(fmt.Errorf("doClose(): %s", reason)) + r.closeMutex.Lock() + close(r.closedNotify) + close(r.writeQueue) + r.conn = nil + r.closeMutex.Unlock() } } diff --git a/relay.go b/relay.go index c0efab7..f841744 100644 --- a/relay.go +++ b/relay.go @@ -14,6 +14,7 @@ import ( "sync/atomic" "time" + ws "github.com/coder/websocket" "github.com/puzpuzpuz/xsync/v3" ) @@ -26,7 +27,12 @@ type Relay struct { URL string requestHeader http.Header // e.g. for origin header - connection *connection + // websocket connection + conn *ws.Conn + writeQueue chan writeRequest + closed *atomic.Bool + closedNotify chan struct{} + Subscriptions *xsync.MapOf[int64, *Subscription] ConnectionError error @@ -98,7 +104,7 @@ func (r *Relay) String() string { func (r *Relay) Context() context.Context { return r.connectionContext } // IsConnected returns true if the connection to this relay seems to be active. -func (r *Relay) IsConnected() bool { return !r.connection.closed.Load() } +func (r *Relay) IsConnected() bool { return !r.closed.Load() } // Connect tries to establish a websocket connection to r.URL. // If the context expires before the connection is complete, an error is returned. @@ -121,11 +127,9 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error return fmt.Errorf("invalid relay URL '%s'", r.URL) } - conn, err := newConnection(ctx, r.connectionContextCancel, r.URL, r.handleMessage, r.requestHeader, tlsConfig) - if err != nil { + if err := r.newConnection(ctx, tlsConfig); err != nil { return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) } - r.connection = conn return nil } @@ -219,33 +223,33 @@ func (r *Relay) handleMessage(message string) { // Write queues an arbitrary message to be sent to the relay. func (r *Relay) Write(msg []byte) { - r.connection.closeMutex.Lock() - defer r.connection.closeMutex.Unlock() + r.closeMutex.Lock() + defer r.closeMutex.Unlock() select { - case <-r.connection.closedNotify: + case <-r.closedNotify: return default: } select { case <-r.connectionContext.Done(): - case r.connection.writeQueue <- writeRequest{msg: msg, answer: nil}: + case r.writeQueue <- writeRequest{msg: msg, answer: nil}: } } // WriteWithError is like Write, but returns an error if the write fails (and the connection gets closed). func (r *Relay) WriteWithError(msg []byte) error { ch := make(chan error) - r.connection.closeMutex.Lock() - defer r.connection.closeMutex.Unlock() + r.closeMutex.Lock() + defer r.closeMutex.Unlock() select { - case <-r.connection.closedNotify: + case <-r.closedNotify: return fmt.Errorf("failed to write to %s: ", r.URL) default: } select { case <-r.connectionContext.Done(): return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext)) - case r.connection.writeQueue <- writeRequest{msg: msg, answer: ch}: + case r.writeQueue <- writeRequest{msg: msg, answer: ch}: } return <-ch } @@ -357,7 +361,7 @@ func (r *Relay) publish(ctx context.Context, id ID, env Envelope) error { func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionOptions) (*Subscription, error) { sub := r.PrepareSubscription(ctx, filter, opts) - if r.connection == nil { + if r.conn == nil { return nil, fmt.Errorf("not connected to %s", r.URL) } @@ -367,7 +371,7 @@ func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionO go func() { select { - case <-r.connection.closedNotify: + case <-r.closedNotify: sub.unsub(ErrDisconnected) case <-ctx.Done(): } @@ -510,13 +514,13 @@ func (r *Relay) close(reason error) error { if r.connectionContextCancel == nil { return fmt.Errorf("relay already closed") } - r.connectionContextCancel(reason) - r.connectionContextCancel = nil - if r.connection == nil { + if r.conn == nil { return fmt.Errorf("relay not connected") } + r.connectionContextCancel(reason) + return nil }